
Cast to >=float32 tensor when passing scalar to self.log (#19046)
Co-authored-by: awaelchli <[email protected]>
(cherry picked from commit 1fcb4ae)
MF-FOOM authored and Borda committed Dec 19, 2023
1 parent ce30fdb commit 1c7bf98
Showing 2 changed files with 9 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -47,6 +47,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue parsing the version from folders that don't include a version number in `TensorBoardLogger` and `CSVLogger` ([#18897](https://github.com/Lightning-AI/lightning/issues/18897))


+- Fixed the tensor conversion in `self.log` to respect the default dtype ([#19046](https://github.com/Lightning-AI/lightning/issues/19046))


## [2.1.0] - 2023-10-11

### Added
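The `module.py` diff below imports a `_get_default_dtype` helper whose definition is not part of this commit. Going by the commit title ("Cast to >=float32 tensor"), it presumably returns the process-wide default floating-point dtype but never anything narrower than float32. A minimal sketch under that assumption:

```python
import torch


def _get_default_dtype() -> torch.dtype:
    # Assumed behavior, inferred from the commit title: honor the global
    # default floating-point dtype, but never drop below float32.
    dtype = torch.get_default_dtype()
    return dtype if dtype in (torch.float32, torch.float64) else torch.float32
```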
7 changes: 6 additions & 1 deletion src/lightning/pytorch/core/module.py
@@ -61,6 +61,7 @@
from lightning.pytorch.loggers import Logger
from lightning.pytorch.trainer import call
from lightning.pytorch.trainer.connectors.logger_connector.fx_validator import _FxValidator
+from lightning.pytorch.trainer.connectors.logger_connector.result import _get_default_dtype
from lightning.pytorch.utilities import GradClipAlgorithmType
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_9_1
@@ -626,7 +627,11 @@ def __check_allowed(v: Any, name: str, value: Any) -> None:
raise ValueError(f"`self.log({name}, {value})` was called, but `{type(v).__name__}` values cannot be logged")

def __to_tensor(self, value: Union[Tensor, numbers.Number], name: str) -> Tensor:
-value = value.clone().detach() if isinstance(value, Tensor) else torch.tensor(value, device=self.device)
+value = (
+    value.clone().detach()
+    if isinstance(value, Tensor)
+    else torch.tensor(value, device=self.device, dtype=_get_default_dtype())
+)
if not torch.numel(value) == 1:
raise ValueError(
f"`self.log({name}, {value})` was called, but the tensor must have a single element."
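Taken together, the patched conversion behaves roughly like the standalone sketch below. `to_logging_tensor` is a hypothetical stand-in for the private `LightningModule.__to_tensor` method (not Lightning's public API), and the dtype clamp mirrors the assumed `_get_default_dtype` above:

```python
import numbers
from typing import Union

import torch
from torch import Tensor


def to_logging_tensor(value: Union[Tensor, numbers.Number], device: Union[torch.device, str] = "cpu") -> Tensor:
    # Hypothetical stand-in mirroring the patched __to_tensor logic.
    if isinstance(value, Tensor):
        value = value.clone().detach()
    else:
        # Scalars are materialized at the default dtype, clamped to >= float32.
        dtype = torch.get_default_dtype()
        if dtype not in (torch.float32, torch.float64):
            dtype = torch.float32
        value = torch.tensor(value, device=device, dtype=dtype)
    if value.numel() != 1:
        raise ValueError("logged values must contain exactly one element")
    return value


# Full precision: the global default dtype is respected.
torch.set_default_dtype(torch.float64)
assert to_logging_tensor(0.5).dtype == torch.float64

# True half precision (e.g. a "bf16-true" setup that sets the default dtype to
# bfloat16 on recent PyTorch): scalars are still promoted to float32 so later
# accumulation does not lose precision.
torch.set_default_dtype(torch.bfloat16)
assert to_logging_tensor(0.5).dtype == torch.float32

torch.set_default_dtype(torch.float32)  # restore the usual default
```

A tensor passed to `self.log` keeps its own dtype; only plain Python scalars are affected by this change.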
