Skip to content

Commit

Permalink
Add @override for files in src/lightning/fabric/loggers (#19090)
Browse files Browse the repository at this point in the history
  • Loading branch information
VictorPrins authored Dec 1, 2023
1 parent 32f5ddd commit 9cb6e8c
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/lightning/fabric/loggers/csv_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Any, Dict, List, Optional, Set, Union

from torch import Tensor
from typing_extensions import override

from lightning.fabric.loggers.logger import Logger, rank_zero_experiment
from lightning.fabric.utilities.cloud_io import _is_dir, get_filesystem
Expand Down Expand Up @@ -73,6 +74,7 @@ def __init__(
self._flush_logs_every_n_steps = flush_logs_every_n_steps

@property
@override
def name(self) -> str:
"""Gets the name of the experiment.
Expand All @@ -83,6 +85,7 @@ def name(self) -> str:
return self._name

@property
@override
def version(self) -> Union[int, str]:
"""Gets the version of the experiment.
Expand All @@ -95,11 +98,13 @@ def version(self) -> Union[int, str]:
return self._version

@property
@override
def root_dir(self) -> str:
    """Return the root directory under which all versioned CSV experiment runs are stored."""
    return self._root_dir

@property
@override
def log_dir(self) -> str:
"""The log directory for this run.
Expand Down Expand Up @@ -128,10 +133,12 @@ def experiment(self) -> "_ExperimentWriter":
self._experiment = _ExperimentWriter(log_dir=self.log_dir)
return self._experiment

@override
@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:  # type: ignore[override]
    """Always raise: the CSV logger has no hyperparameter-logging support yet."""
    raise NotImplementedError("The `CSVLogger` does not yet support logging hyperparameters.")

@override
@rank_zero_only
def log_metrics( # type: ignore[override]
self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None
Expand All @@ -143,11 +150,13 @@ def log_metrics( # type: ignore[override]
if (step + 1) % self._flush_logs_every_n_steps == 0:
self.save()

@override
@rank_zero_only
def save(self) -> None:
    """Flush any buffered metrics to disk through the experiment writer."""
    # Let the base logger run its save hook first, then persist the CSV rows.
    super().save()
    self.experiment.save()

@override
@rank_zero_only
def finalize(self, status: str) -> None:
if self._experiment is None:
Expand Down
10 changes: 10 additions & 0 deletions src/lightning/fabric/loggers/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
from torch.nn import Module
from typing_extensions import override

from lightning.fabric.loggers.logger import Logger, rank_zero_experiment
from lightning.fabric.utilities.cloud_io import _is_dir, get_filesystem
Expand Down Expand Up @@ -109,6 +110,7 @@ def __init__(
self._kwargs = kwargs

@property
@override
def name(self) -> str:
"""Get the name of the experiment.
Expand All @@ -119,6 +121,7 @@ def name(self) -> str:
return self._name

@property
@override
def version(self) -> Union[int, str]:
"""Get the experiment version.
Expand All @@ -131,6 +134,7 @@ def version(self) -> Union[int, str]:
return self._version

@property
@override
def root_dir(self) -> str:
"""Gets the save directory where the TensorBoard experiments are saved.
Expand All @@ -141,6 +145,7 @@ def root_dir(self) -> str:
return self._root_dir

@property
@override
def log_dir(self) -> str:
"""The directory for this run's tensorboard checkpoint.
Expand Down Expand Up @@ -191,6 +196,7 @@ def experiment(self) -> "SummaryWriter":
self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
return self._experiment

@override
@rank_zero_only
def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
Expand All @@ -212,6 +218,7 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None)
f"\n you tried to log {v} which is currently not supported. Try a dict or a scalar/tensor."
) from ex

@override
@rank_zero_only
def log_hyperparams( # type: ignore[override]
self, params: Union[Dict[str, Any], Namespace], metrics: Optional[Dict[str, Any]] = None
Expand Down Expand Up @@ -251,6 +258,7 @@ def log_hyperparams( # type: ignore[override]
writer.add_summary(ssi)
writer.add_summary(sei)

@override
@rank_zero_only
def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None:
model_example_input = getattr(model, "example_input_array", None)
Expand Down Expand Up @@ -278,10 +286,12 @@ def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None
else:
self.experiment.add_graph(model, input_array)

@override
@rank_zero_only
def save(self) -> None:
    """Flush pending events from the TensorBoard ``SummaryWriter`` to disk."""
    self.experiment.flush()

@override
@rank_zero_only
def finalize(self, status: str) -> None:
if self._experiment is not None:
Expand Down

0 comments on commit 9cb6e8c

Please sign in to comment.