From a011366dda899b2fc24c3e2efd4658848e6ee0aa Mon Sep 17 00:00:00 2001 From: Mr WhatZitTooYa Date: Tue, 24 Sep 2024 23:13:23 -0400 Subject: [PATCH 01/17] Add feature implementation to datamodule for str method First implementation scetch --- src/lightning/pytorch/core/datamodule.py | 48 +++++++++++++++++++----- 1 file changed, 38 insertions(+), 10 deletions(-) diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index 6cb8f79f09284..4e57da79691eb 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -16,18 +16,17 @@ import inspect from typing import IO, Any, Dict, Iterable, Optional, Union, cast +import pytorch_lightning as pl +from lightning_fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH from lightning_utilities import apply_to_collection +from pytorch_lightning.core.hooks import DataHooks +from pytorch_lightning.core.mixins import HyperparametersMixin +from pytorch_lightning.core.saving import _load_from_checkpoint +from pytorch_lightning.utilities.model_helpers import _restricted_classmethod +from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from torch.utils.data import DataLoader, Dataset, IterableDataset from typing_extensions import Self -import lightning.pytorch as pl -from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH -from lightning.pytorch.core.hooks import DataHooks -from lightning.pytorch.core.mixins import HyperparametersMixin -from lightning.pytorch.core.saving import _load_from_checkpoint -from lightning.pytorch.utilities.model_helpers import _restricted_classmethod -from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS - class LightningDataModule(DataHooks, HyperparametersMixin): """A DataModule standardizes the training, val, test splits, data preparation and transforms. 
The main advantage is @@ -35,9 +34,9 @@ class LightningDataModule(DataHooks, HyperparametersMixin): Example:: - import lightning as L + import lightning.pytorch as L import torch.utils.data as data - from lightning.pytorch.demos.boring_classes import RandomDataset + from pytorch_lightning.demos.boring_classes import RandomDataset class MyDataModule(L.LightningDataModule): def prepare_data(self): @@ -243,3 +242,32 @@ def load_from_checkpoint( **kwargs, ) return cast(Self, loaded) + + def __str__(self) -> str: + """Return a string representation of the datasets that are setup. + + Returns: + A string representation of the datasets that are setup. + + """ + datasets_info = [] + + for attr_name in dir(self): + attr = getattr(self, attr_name) + + # Get Dataset information + if isinstance(attr, Dataset): + if hasattr(attr, "__len__"): + datasets_info.append(f"{attr_name}, dataset size={len(attr)}") + else: + datasets_info.append(f"{attr_name}, dataset size=Unavailable") + elif isinstance(attr, (list, tuple)) and all(isinstance(item, Dataset) for item in attr): + if all(hasattr(item, "__len__") for item in attr): + datasets_info.append(f"{attr_name}, dataset size={[len(ds) for ds in attr]}") + else: + datasets_info.append(f"{attr_name}, dataset size=Unavailable") + + if not datasets_info: + return "No datasets are set up." 
+ + return "\n".join(datasets_info) From 137e7b54ca51046cc062f1a3c0f3c9d0787999cc Mon Sep 17 00:00:00 2001 From: Mr WhatZitTooYa Date: Wed, 25 Sep 2024 11:17:35 -0400 Subject: [PATCH 02/17] Removed list / tuple case for datamodule str method --- src/lightning/pytorch/core/datamodule.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index 4e57da79691eb..00d908c178f1b 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -261,11 +261,6 @@ def __str__(self) -> str: datasets_info.append(f"{attr_name}, dataset size={len(attr)}") else: datasets_info.append(f"{attr_name}, dataset size=Unavailable") - elif isinstance(attr, (list, tuple)) and all(isinstance(item, Dataset) for item in attr): - if all(hasattr(item, "__len__") for item in attr): - datasets_info.append(f"{attr_name}, dataset size={[len(ds) for ds in attr]}") - else: - datasets_info.append(f"{attr_name}, dataset size=Unavailable") if not datasets_info: return "No datasets are set up." 
From efe0c3c8956b8c9879d16aafea4d6f15afa7c839 Mon Sep 17 00:00:00 2001 From: Mr WhatZitTooYa Date: Wed, 25 Sep 2024 13:45:30 -0400 Subject: [PATCH 03/17] Added test cases for DataModule string function Added alternative Boring Data Module implementations Added test cases for all possible options Added additional check for NotImplementedError in string function of DataModule --- src/lightning/pytorch/core/datamodule.py | 15 +++-- src/lightning/pytorch/demos/boring_classes.py | 32 ++++++++++ tests/tests_pytorch/core/test_datamodules.py | 63 ++++++++++++++++++- 3 files changed, 105 insertions(+), 5 deletions(-) diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index 00d908c178f1b..f521fead755c8 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -252,17 +252,24 @@ def __str__(self) -> str: """ datasets_info = [] + def len_implemented(obj): + try: + len(obj) + return True + except NotImplementedError: + return False + for attr_name in dir(self): attr = getattr(self, attr_name) # Get Dataset information if isinstance(attr, Dataset): - if hasattr(attr, "__len__"): - datasets_info.append(f"{attr_name}, dataset size={len(attr)}") + if hasattr(attr, "__len__") and len_implemented(attr): + datasets_info.append(f"name={attr_name}, size={len(attr)}") else: - datasets_info.append(f"{attr_name}, dataset size=Unavailable") + datasets_info.append(f"name={attr_name}, size=Unavailable") if not datasets_info: return "No datasets are set up." 
- return "\n".join(datasets_info) + return "\n".join(datasets_info) + "\n" diff --git a/src/lightning/pytorch/demos/boring_classes.py b/src/lightning/pytorch/demos/boring_classes.py index fd2660228146e..637a1e131809a 100644 --- a/src/lightning/pytorch/demos/boring_classes.py +++ b/src/lightning/pytorch/demos/boring_classes.py @@ -187,6 +187,38 @@ def predict_dataloader(self) -> DataLoader: return DataLoader(self.random_predict) +class BoringDataModuleNoLen(LightningDataModule): + """ + .. warning:: This is meant for testing/debugging and is experimental. + """ + + def __init__(self) -> None: + super().__init__() + self.random_full = RandomIterableDataset(32, 64 * 4) + + +class BoringDataModuleLenNotImplemented(LightningDataModule): + """ + .. warning:: This is meant for testing/debugging and is experimental. + """ + + def __init__(self) -> None: + super().__init__() + + class DS(Dataset): + def __init__(self, size: int, length: int): + self.len = length + self.data = torch.randn(length, size) + + def __getitem__(self, index: int) -> Tensor: + return self.data[index] + + def __len__(self) -> int: + raise NotImplementedError + + self.random_full = DS(32, 64 * 4) + + class ManualOptimBoringModel(BoringModel): """ .. warning:: This is meant for testing/debugging and is experimental. 
diff --git a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py index 65fccb691a33d..0739359d6b5ae 100644 --- a/tests/tests_pytorch/core/test_datamodules.py +++ b/tests/tests_pytorch/core/test_datamodules.py @@ -22,7 +22,12 @@ import torch from lightning.pytorch import LightningDataModule, Trainer, seed_everything from lightning.pytorch.callbacks import ModelCheckpoint -from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel +from lightning.pytorch.demos.boring_classes import ( + BoringDataModule, + BoringDataModuleLenNotImplemented, + BoringDataModuleNoLen, + BoringModel, +) from lightning.pytorch.profilers.simple import SimpleProfiler from lightning.pytorch.trainer.states import TrainerFn from lightning.pytorch.utilities import AttributeDict @@ -510,3 +515,59 @@ def prepare_data(self): durations = profiler.recorded_durations[key] assert len(durations) == 1 assert durations[0] > 0 + + +def test_datamodule_string_no_datasets(): + dm = BoringDataModule() + del dm.random_full + expected_output = "No datasets are set up." 
+ assert str(dm) == expected_output + + +def test_datamodule_string_no_length(): + dm = BoringDataModuleNoLen() + expected_output = "name=random_full, size=Unavailable\n" + assert str(dm) == expected_output + + +def test_datamodule_string_length_not_implemented(): + dm = BoringDataModuleLenNotImplemented() + expected_output = "name=random_full, size=Unavailable\n" + assert str(dm) == expected_output + + +def test_datamodule_string_fit_setup(): + dm = BoringDataModule() + dm.setup(stage="fit") + + expected_outputs = ["name=random_full, size=256\n", "name=random_train, size=64\n", "name=random_val, size=64\n"] + output = str(dm) + for expected_output in expected_outputs: + assert expected_output in output + + +def test_datamodule_string_validation_setup(): + dm = BoringDataModule() + dm.setup(stage="validate") + expected_outputs = ["name=random_full, size=256\n", "name=random_val, size=64\n"] + output = str(dm) + for expected_output in expected_outputs: + assert expected_output in output + + +def test_datamodule_string_test_setup(): + dm = BoringDataModule() + dm.setup(stage="test") + expected_outputs = ["name=random_full, size=256\n", "name=random_test, size=64\n"] + output = str(dm) + for expected_output in expected_outputs: + assert expected_output in output + + +def test_datamodule_string_predict_setup(): + dm = BoringDataModule() + dm.setup(stage="predict") + expected_outputs = ["name=random_full, size=256\n", "name=random_predict, size=64\n"] + output = str(dm) + for expected_output in expected_outputs: + assert expected_output in output From 23326d7072cb689ace017b77e3c3a97af9846dae Mon Sep 17 00:00:00 2001 From: Mr WhatZitTooYa Date: Wed, 25 Sep 2024 16:24:49 -0400 Subject: [PATCH 04/17] Reverted accidental changes in DataModule --- src/lightning/pytorch/core/datamodule.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index 
f521fead755c8..18bee4474e83d 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -16,17 +16,18 @@ import inspect from typing import IO, Any, Dict, Iterable, Optional, Union, cast -import pytorch_lightning as pl -from lightning_fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH from lightning_utilities import apply_to_collection -from pytorch_lightning.core.hooks import DataHooks -from pytorch_lightning.core.mixins import HyperparametersMixin -from pytorch_lightning.core.saving import _load_from_checkpoint -from pytorch_lightning.utilities.model_helpers import _restricted_classmethod -from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from torch.utils.data import DataLoader, Dataset, IterableDataset from typing_extensions import Self +import lightning.pytorch as pl +from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH +from lightning.pytorch.core.hooks import DataHooks +from lightning.pytorch.core.mixins import HyperparametersMixin +from lightning.pytorch.core.saving import _load_from_checkpoint +from lightning.pytorch.utilities.model_helpers import _restricted_classmethod +from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS + class LightningDataModule(DataHooks, HyperparametersMixin): """A DataModule standardizes the training, val, test splits, data preparation and transforms. 
The main advantage is @@ -34,9 +35,9 @@ class LightningDataModule(DataHooks, HyperparametersMixin): Example:: - import lightning.pytorch as L + import lightning as L import torch.utils.data as data - from pytorch_lightning.demos.boring_classes import RandomDataset + from lightning.pytorch.demos.boring_classes import RandomDataset class MyDataModule(L.LightningDataModule): def prepare_data(self): From 122cf6d186197bedbc4ecb4271cbe1cb21f60ad6 Mon Sep 17 00:00:00 2001 From: Mr WhatZitTooYa Date: Wed, 20 Nov 2024 15:37:25 -0500 Subject: [PATCH 05/17] Updated dataloader str method Made changes to comply with requested suggestions Switched from hardcoded \n to more general os.linesep --- src/lightning/pytorch/core/datamodule.py | 5 +-- tests/tests_pytorch/core/test_datamodules.py | 34 ++++++++++++-------- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index 18bee4474e83d..a79422fd8f942 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -14,6 +14,7 @@ """LightningDataModule for loading DataLoaders with ease.""" import inspect +import os from typing import IO, Any, Dict, Iterable, Optional, Union, cast from lightning_utilities import apply_to_collection @@ -251,7 +252,7 @@ def __str__(self) -> str: A string representation of the datasets that are setup. """ - datasets_info = [] + datasets_info: Optional[str] = [] def len_implemented(obj): try: @@ -273,4 +274,4 @@ def len_implemented(obj): if not datasets_info: return "No datasets are set up." 
- return "\n".join(datasets_info) + "\n" + return os.linesep.join(datasets_info) diff --git a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py index 0739359d6b5ae..2ed7a04a146dd 100644 --- a/tests/tests_pytorch/core/test_datamodules.py +++ b/tests/tests_pytorch/core/test_datamodules.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os import pickle from argparse import Namespace from dataclasses import dataclass @@ -526,13 +527,13 @@ def test_datamodule_string_no_datasets(): def test_datamodule_string_no_length(): dm = BoringDataModuleNoLen() - expected_output = "name=random_full, size=Unavailable\n" + expected_output = "name=random_full, size=Unavailable" assert str(dm) == expected_output def test_datamodule_string_length_not_implemented(): dm = BoringDataModuleLenNotImplemented() - expected_output = "name=random_full, size=Unavailable\n" + expected_output = "name=random_full, size=Unavailable" assert str(dm) == expected_output @@ -540,34 +541,39 @@ def test_datamodule_string_fit_setup(): dm = BoringDataModule() dm.setup(stage="fit") - expected_outputs = ["name=random_full, size=256\n", "name=random_train, size=64\n", "name=random_val, size=64\n"] + expected_output = ( + f"name=random_full, size=256{os.linesep}" f"name=random_train, size=64{os.linesep}" f"name=random_val, size=64" + ) output = str(dm) - for expected_output in expected_outputs: - assert expected_output in output + + assert expected_output == output def test_datamodule_string_validation_setup(): dm = BoringDataModule() dm.setup(stage="validate") - expected_outputs = ["name=random_full, size=256\n", "name=random_val, size=64\n"] + + expected_output = f"name=random_full, size=256{os.linesep}" f"name=random_val, size=64" output = str(dm) - for expected_output in expected_outputs: - assert expected_output 
in output + + assert expected_output == output def test_datamodule_string_test_setup(): dm = BoringDataModule() dm.setup(stage="test") - expected_outputs = ["name=random_full, size=256\n", "name=random_test, size=64\n"] + + expected_output = f"name=random_full, size=256{os.linesep}" f"name=random_test, size=64" output = str(dm) - for expected_output in expected_outputs: - assert expected_output in output + + assert expected_output == output def test_datamodule_string_predict_setup(): dm = BoringDataModule() dm.setup(stage="predict") - expected_outputs = ["name=random_full, size=256\n", "name=random_predict, size=64\n"] + + expected_output = f"name=random_full, size=256{os.linesep}" f"name=random_predict, size=64" output = str(dm) - for expected_output in expected_outputs: - assert expected_output in output + + assert expected_output == output From 51a49017e4ad0170adc2ade33b2bb945b0db6add Mon Sep 17 00:00:00 2001 From: Mr WhatZitTooYa Date: Wed, 20 Nov 2024 21:08:46 -0500 Subject: [PATCH 06/17] Improvements to implementation of str method for datamodule Corrected the annotation for the internal function and the list that is suppsoed to store the information on the datasets --- src/lightning/pytorch/core/datamodule.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index a79422fd8f942..197ed0a9cd81e 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -15,7 +15,7 @@ import inspect import os -from typing import IO, Any, Dict, Iterable, Optional, Union, cast +from typing import IO, Any, Dict, Iterable, List, Optional, Union, cast from lightning_utilities import apply_to_collection from torch.utils.data import DataLoader, Dataset, IterableDataset @@ -252,9 +252,9 @@ def __str__(self) -> str: A string representation of the datasets that are setup. 
""" - datasets_info: Optional[str] = [] + datasets_info: Optional[List[str]] = [] - def len_implemented(obj): + def len_implemented(obj: Dataset) -> bool: try: len(obj) return True From 1ce0f920e23e4dc0656f1258b7c7b2587cde92e8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 25 Nov 2024 09:09:15 +0000 Subject: [PATCH 07/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/pytorch/core/datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index 58e86c19cb3a2..079568b625910 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -13,9 +13,9 @@ # limitations under the License. """LightningDataModule for loading DataLoaders with ease.""" -from collections.abc import Iterable import inspect import os +from collections.abc import Iterable from typing import IO, Any, Optional, Union, cast from lightning_utilities import apply_to_collection From 20a2a0c56e1c1948f17dc9c521f2e715b44fe1e0 Mon Sep 17 00:00:00 2001 From: Mr WhatZitTooYa Date: Mon, 25 Nov 2024 13:50:36 -0500 Subject: [PATCH 08/17] Implementing str method for datamodule Fixed type annotation issue Reduced code size by using Sized object from abc library --- src/lightning/pytorch/core/datamodule.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index 079568b625910..486cc2db177c7 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -15,7 +15,7 @@ import inspect import os -from collections.abc import Iterable +from collections.abc import Sized, Iterable from typing import IO, Any, Optional, Union, cast from lightning_utilities import apply_to_collection @@ 
-253,21 +253,14 @@ def __str__(self) -> str: A string representation of the datasets that are setup. """ - datasets_info: Optional[list[str]] = [] - - def len_implemented(obj: Dataset) -> bool: - try: - len(obj) - return True - except NotImplementedError: - return False + datasets_info: list[str] = [] for attr_name in dir(self): attr = getattr(self, attr_name) # Get Dataset information if isinstance(attr, Dataset): - if hasattr(attr, "__len__") and len_implemented(attr): + if isinstance(attr, Sized): datasets_info.append(f"name={attr_name}, size={len(attr)}") else: datasets_info.append(f"name={attr_name}, size=Unavailable") From e03aefba76b9c1e05a94bcb9afb66abf237616e2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 25 Nov 2024 18:55:14 +0000 Subject: [PATCH 09/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/pytorch/core/datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index 486cc2db177c7..b0f85094a67a4 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -15,7 +15,7 @@ import inspect import os -from collections.abc import Sized, Iterable +from collections.abc import Iterable, Sized from typing import IO, Any, Optional, Union, cast from lightning_utilities import apply_to_collection From 6d05cfef05ff24ea4a190a0d16c1d0f0a06ca5cc Mon Sep 17 00:00:00 2001 From: Mr WhatZitTooYa Date: Mon, 25 Nov 2024 18:11:48 -0500 Subject: [PATCH 10/17] Add string method to datamodule Switched from Dataset based implementation to Dataloader based implementation --- src/lightning/pytorch/core/datamodule.py | 69 ++++++++++++++++++------ 1 file changed, 53 insertions(+), 16 deletions(-) diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py 
index b0f85094a67a4..6038e7eae523a 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -27,6 +27,7 @@ from lightning.pytorch.core.hooks import DataHooks from lightning.pytorch.core.mixins import HyperparametersMixin from lightning.pytorch.core.saving import _load_from_checkpoint +from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.model_helpers import _restricted_classmethod from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS @@ -247,25 +248,61 @@ def load_from_checkpoint( return cast(Self, loaded) def __str__(self) -> str: - """Return a string representation of the datasets that are setup. + """Return a string representation of the datasets that are set up. Returns: A string representation of the datasets that are setup. """ - datasets_info: list[str] = [] - for attr_name in dir(self): - attr = getattr(self, attr_name) - - # Get Dataset information - if isinstance(attr, Dataset): - if isinstance(attr, Sized): - datasets_info.append(f"name={attr_name}, size={len(attr)}") - else: - datasets_info.append(f"name={attr_name}, size=Unavailable") - - if not datasets_info: - return "No datasets are set up." 
- - return os.linesep.join(datasets_info) + def dataset_info(loader: DataLoader) -> tuple[str, str]: + """Helper function to compute dataset information.""" + dataset = loader.dataset + size: str + size = str(len(dataset)) if isinstance(dataset, Sized) else "size: unknown" + + return str(dataset), size + + def loader_info(loader_instance: Union[DataLoader, Iterable[DataLoader]]) -> str: + """Helper function to compute dataset information.""" + return apply_to_collection(loader_instance, tuple[str, str], dataset_info) + + dataloader_methods: list[tuple[str, str]] = [ + ("Train dataset", "train_dataloader"), + ("Validation dataset", "val_dataloader"), + ("Test dataset", "test_dataloader"), + ("Prediction dataset", "predict_dataloader"), + ] + dataloader_info: dict[str, Union[tuple[str, str], Iterable[tuple[str, str]]]] = {} + + # Retrieve information for each dataloader method + for method_pair in dataloader_methods: + method_str, method_name = method_pair + loader_method = getattr(self, method_name, None) + + if loader_method and callable(loader_method): + try: + loader_instance = loader_method() + dataloader_info[method_str] = loader_info(loader_instance) + except MisconfigurationException: + dataloader_info[method_str] = f"{method_str}: not implemented" + except Exception as e: + dataloader_info[method_str] = f"{method_name}: error - {str(e)}" + else: + dataloader_info[method_str] = f"{method_name}: not callable" + + # Format the information + dataloader_str: str = "" + for method_str, method_info in dataloader_info.items(): + if isinstance(method_info, tuple[str, str]): + dataloader_str += f"{{{method_str}: " + dataloader_str += f"name={method_info[0]}, size={method_info[1]}" + dataloader_str += f"}}{os.linesep}" + else: + dataloader_str += f"{{{method_str}: " + for info in method_info: + dataloader_str += f"name={info[0]}, size={info[1]} ; " + dataloader_str = dataloader_str[:-3] + dataloader_str += f"}}{os.linesep}" + + return dataloader_str From 
acd2f73032a3e44518877c7b8c905f59253c206a Mon Sep 17 00:00:00 2001 From: Mr WhatZitTooYa Date: Wed, 4 Dec 2024 13:27:48 -0500 Subject: [PATCH 11/17] Implementing str mehtod for dataloader Added missing size value to tuple in the error case instead of returning only a string --- src/lightning/pytorch/core/datamodule.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index 6038e7eae523a..8971a8ac27eb8 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -27,7 +27,6 @@ from lightning.pytorch.core.hooks import DataHooks from lightning.pytorch.core.mixins import HyperparametersMixin from lightning.pytorch.core.saving import _load_from_checkpoint -from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.model_helpers import _restricted_classmethod from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS @@ -267,38 +266,45 @@ def loader_info(loader_instance: Union[DataLoader, Iterable[DataLoader]]) -> str """Helper function to compute dataset information.""" return apply_to_collection(loader_instance, tuple[str, str], dataset_info) - dataloader_methods: list[tuple[str, str]] = [ + dataloader_methods: dict[str, str] = { ("Train dataset", "train_dataloader"), ("Validation dataset", "val_dataloader"), ("Test dataset", "test_dataloader"), ("Prediction dataset", "predict_dataloader"), - ] + } dataloader_info: dict[str, Union[tuple[str, str], Iterable[tuple[str, str]]]] = {} # Retrieve information for each dataloader method for method_pair in dataloader_methods: method_str, method_name = method_pair loader_method = getattr(self, method_name, None) + print("Method name: ", method_name) if loader_method and callable(loader_method): try: loader_instance = loader_method() dataloader_info[method_str] = loader_info(loader_instance) 
- except MisconfigurationException: - dataloader_info[method_str] = f"{method_str}: not implemented" - except Exception as e: - dataloader_info[method_str] = f"{method_name}: error - {str(e)}" + print("loader instance") + except Exception as _: + dataloader_info[method_str] = (f"{method_str}: not available", "size: unknown") + print("Misconfiguration") else: - dataloader_info[method_str] = f"{method_name}: not callable" + dataloader_info[method_str] = (f"{method_name}: not callable", "size: unknown") + + print() # Format the information dataloader_str: str = "" for method_str, method_info in dataloader_info.items(): - if isinstance(method_info, tuple[str, str]): + # Single data set + print("Method info: ", method_info) + print(type(method_info)) + if isinstance(method_info, tuple): dataloader_str += f"{{{method_str}: " dataloader_str += f"name={method_info[0]}, size={method_info[1]}" dataloader_str += f"}}{os.linesep}" else: + # Iterable of datasets dataloader_str += f"{{{method_str}: " for info in method_info: dataloader_str += f"name={info[0]}, size={info[1]} ; " From b08e1fe84bb60a0b749105398c5c42936cf3d76f Mon Sep 17 00:00:00 2001 From: Mr WhatZitTooYa Date: Wed, 4 Dec 2024 16:55:03 -0500 Subject: [PATCH 12/17] Implementing str fucntion for datamodule Adjusted test to match the new implementation requirenemnts Added necessary BoringModules for tests Fixed bugs and annotation issues in the str method --- src/lightning/pytorch/core/datamodule.py | 39 ++++---- src/lightning/pytorch/demos/boring_classes.py | 79 ++++++++++++---- tests/tests_pytorch/core/test_datamodules.py | 89 +++++++++++++++---- 3 files changed, 149 insertions(+), 58 deletions(-) diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index 8971a8ac27eb8..645c1c7551ad3 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -257,57 +257,48 @@ def __str__(self) -> str: def dataset_info(loader: DataLoader) -> 
tuple[str, str]: """Helper function to compute dataset information.""" dataset = loader.dataset - size: str - size = str(len(dataset)) if isinstance(dataset, Sized) else "size: unknown" + size: str = str(len(dataset)) if isinstance(dataset, Sized) else "unknown" - return str(dataset), size + return "yes", size def loader_info(loader_instance: Union[DataLoader, Iterable[DataLoader]]) -> str: """Helper function to compute dataset information.""" - return apply_to_collection(loader_instance, tuple[str, str], dataset_info) + result = apply_to_collection(loader_instance, DataLoader, dataset_info) - dataloader_methods: dict[str, str] = { + return result + + dataloader_methods: list[tuple[str, str]] = [ ("Train dataset", "train_dataloader"), ("Validation dataset", "val_dataloader"), ("Test dataset", "test_dataloader"), ("Prediction dataset", "predict_dataloader"), - } + ] dataloader_info: dict[str, Union[tuple[str, str], Iterable[tuple[str, str]]]] = {} # Retrieve information for each dataloader method for method_pair in dataloader_methods: method_str, method_name = method_pair loader_method = getattr(self, method_name, None) - print("Method name: ", method_name) - - if loader_method and callable(loader_method): - try: - loader_instance = loader_method() - dataloader_info[method_str] = loader_info(loader_instance) - print("loader instance") - except Exception as _: - dataloader_info[method_str] = (f"{method_str}: not available", "size: unknown") - print("Misconfiguration") - else: - dataloader_info[method_str] = (f"{method_name}: not callable", "size: unknown") - print() + try: + loader_instance = loader_method() + dataloader_info[method_str] = loader_info(loader_instance) + except Exception: + dataloader_info[method_str] = ("no", "unknown") # Format the information dataloader_str: str = "" for method_str, method_info in dataloader_info.items(): # Single data set - print("Method info: ", method_info) - print(type(method_info)) if isinstance(method_info, tuple): 
dataloader_str += f"{{{method_str}: " - dataloader_str += f"name={method_info[0]}, size={method_info[1]}" + dataloader_str += f"available={method_info[0]}, size={method_info[1]}" dataloader_str += f"}}{os.linesep}" else: # Iterable of datasets dataloader_str += f"{{{method_str}: " - for info in method_info: - dataloader_str += f"name={info[0]}, size={info[1]} ; " + for i, info in enumerate(method_info, start=1): + dataloader_str += f"{i}. available={info[0]}, size={info[1]} ; " dataloader_str = dataloader_str[:-3] dataloader_str += f"}}{os.linesep}" diff --git a/src/lightning/pytorch/demos/boring_classes.py b/src/lightning/pytorch/demos/boring_classes.py index 2c0d5596f7d6e..3855f31898b81 100644 --- a/src/lightning/pytorch/demos/boring_classes.py +++ b/src/lightning/pytorch/demos/boring_classes.py @@ -11,12 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Iterator +from collections.abc import Iterable, Iterator from typing import Any, Optional import torch import torch.nn as nn import torch.nn.functional as F +from lightning_utilities import apply_to_collection from torch import Tensor from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler @@ -195,29 +196,77 @@ class BoringDataModuleNoLen(LightningDataModule): def __init__(self) -> None: super().__init__() - self.random_full = RandomIterableDataset(32, 64 * 4) + def setup(self, stage: str) -> None: + if stage == "fit": + self.random_train = RandomIterableDataset(32, 512) -class BoringDataModuleLenNotImplemented(LightningDataModule): - """ - .. warning:: This is meant for testing/debugging and is experimental. 
- """ + if stage in ("fit", "validate"): + self.random_val = RandomIterableDataset(32, 128) + + if stage == "test": + self.random_test = RandomIterableDataset(32, 256) + + if stage == "predict": + self.random_predict = RandomIterableDataset(32, 64) + def train_dataloader(self) -> DataLoader: + return DataLoader(self.random_train) + + def val_dataloader(self) -> DataLoader: + return DataLoader(self.random_val) + + def test_dataloader(self) -> DataLoader: + return DataLoader(self.random_test) + + def predict_dataloader(self) -> DataLoader: + return DataLoader(self.random_predict) + + +class IterableBoringDataModule(LightningDataModule): def __init__(self) -> None: super().__init__() - class DS(Dataset): - def __init__(self, size: int, length: int): - self.len = length - self.data = torch.randn(length, size) + def setup(self, stage: str) -> None: + if stage == "fit": + self.train_datasets = [ + RandomDataset(4, 16), + RandomIterableDataset(4, 16), + ] - def __getitem__(self, index: int) -> Tensor: - return self.data[index] + if stage in ("fit", "validate"): + self.val_datasets = [ + RandomDataset(4, 32), + RandomIterableDataset(4, 32), + ] - def __len__(self) -> int: - raise NotImplementedError + if stage == "test": + self.test_datasets = [ + RandomDataset(4, 64), + RandomIterableDataset(4, 64), + ] - self.random_full = DS(32, 64 * 4) + if stage == "predict": + self.predict_datasets = [ + RandomDataset(4, 128), + RandomIterableDataset(4, 128), + ] + + def train_dataloader(self) -> Iterable[DataLoader]: + combined_train = apply_to_collection(self.train_datasets, Dataset, lambda x: DataLoader(x)) + return combined_train + + def val_dataloader(self) -> DataLoader: + combined_val = apply_to_collection(self.val_datasets, Dataset, lambda x: DataLoader(x)) + return combined_val + + def test_dataloader(self) -> DataLoader: + combined_test = apply_to_collection(self.test_datasets, Dataset, lambda x: DataLoader(x)) + return combined_test + + def predict_dataloader(self) -> 
DataLoader: + combined_predict = apply_to_collection(self.predict_datasets, Dataset, lambda x: DataLoader(x)) + return combined_predict class ManualOptimBoringModel(BoringModel): diff --git a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py index 90a810fb10969..d02b09b2b5268 100644 --- a/tests/tests_pytorch/core/test_datamodules.py +++ b/tests/tests_pytorch/core/test_datamodules.py @@ -25,9 +25,9 @@ from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.demos.boring_classes import ( BoringDataModule, - BoringDataModuleLenNotImplemented, BoringDataModuleNoLen, BoringModel, + IterableBoringDataModule, ) from lightning.pytorch.profilers.simple import SimpleProfiler from lightning.pytorch.trainer.states import TrainerFn @@ -518,34 +518,40 @@ def prepare_data(self): assert durations[0] > 0 -def test_datamodule_string_no_datasets(): +# TODO: Remove last os.linesep +def test_datamodule_string_not_available(): dm = BoringDataModule() - del dm.random_full - expected_output = "No datasets are set up." 
- assert str(dm) == expected_output - - -def test_datamodule_string_no_length(): - dm = BoringDataModuleNoLen() - expected_output = "name=random_full, size=Unavailable" - assert str(dm) == expected_output + expected_output = ( + f"{{Train dataset: available=no, size=unknown}}{os.linesep}" + f"{{Validation dataset: available=no, size=unknown}}{os.linesep}" + f"{{Test dataset: available=no, size=unknown}}{os.linesep}" + f"{{Prediction dataset: available=no, size=unknown}}{os.linesep}" + ) + out = str(dm) -def test_datamodule_string_length_not_implemented(): - dm = BoringDataModuleLenNotImplemented() - expected_output = "name=random_full, size=Unavailable" - assert str(dm) == expected_output + assert out == expected_output +# TODO Remove prints def test_datamodule_string_fit_setup(): dm = BoringDataModule() dm.setup(stage="fit") expected_output = ( - f"name=random_full, size=256{os.linesep}" f"name=random_train, size=64{os.linesep}" f"name=random_val, size=64" + f"{{Train dataset: available=yes, size=64}}{os.linesep}" + f"{{Validation dataset: available=yes, size=64}}{os.linesep}" + f"{{Test dataset: available=no, size=unknown}}{os.linesep}" + f"{{Prediction dataset: available=no, size=unknown}}{os.linesep}" ) output = str(dm) + print() + print(repr(expected_output)) + print() + print(repr(output)) + print() + assert expected_output == output @@ -553,7 +559,12 @@ def test_datamodule_string_validation_setup(): dm = BoringDataModule() dm.setup(stage="validate") - expected_output = f"name=random_full, size=256{os.linesep}" f"name=random_val, size=64" + expected_output = ( + f"{{Train dataset: available=no, size=unknown}}{os.linesep}" + f"{{Validation dataset: available=yes, size=64}}{os.linesep}" + f"{{Test dataset: available=no, size=unknown}}{os.linesep}" + f"{{Prediction dataset: available=no, size=unknown}}{os.linesep}" + ) output = str(dm) assert expected_output == output @@ -563,7 +574,12 @@ def test_datamodule_string_test_setup(): dm = BoringDataModule() 
dm.setup(stage="test") - expected_output = f"name=random_full, size=256{os.linesep}" f"name=random_test, size=64" + expected_output = ( + f"{{Train dataset: available=no, size=unknown}}{os.linesep}" + f"{{Validation dataset: available=no, size=unknown}}{os.linesep}" + f"{{Test dataset: available=yes, size=64}}{os.linesep}" + f"{{Prediction dataset: available=no, size=unknown}}{os.linesep}" + ) output = str(dm) assert expected_output == output @@ -573,7 +589,42 @@ def test_datamodule_string_predict_setup(): dm = BoringDataModule() dm.setup(stage="predict") - expected_output = f"name=random_full, size=256{os.linesep}" f"name=random_predict, size=64" + expected_output = ( + f"{{Train dataset: available=no, size=unknown}}{os.linesep}" + f"{{Validation dataset: available=no, size=unknown}}{os.linesep}" + f"{{Test dataset: available=no, size=unknown}}{os.linesep}" + f"{{Prediction dataset: available=yes, size=64}}{os.linesep}" + ) output = str(dm) assert expected_output == output + + +def test_datamodule_string_no_len(): + dm = BoringDataModuleNoLen() + dm.setup("fit") + + expected_output = ( + f"{{Train dataset: available=yes, size=unknown}}{os.linesep}" + f"{{Validation dataset: available=yes, size=unknown}}{os.linesep}" + f"{{Test dataset: available=no, size=unknown}}{os.linesep}" + f"{{Prediction dataset: available=no, size=unknown}}{os.linesep}" + ) + output = str(dm) + + assert output == expected_output + + +def test_datamodule_string_iterable(): + dm = IterableBoringDataModule() + dm.setup("fit") + + expected_output = ( + f"{{Train dataset: 1. available=yes, size=16 ; 2. available=yes, size=unknown}}{os.linesep}" + f"{{Validation dataset: 1. available=yes, size=32 ; 2. 
available=yes, size=unknown}}{os.linesep}" + f"{{Test dataset: available=no, size=unknown}}{os.linesep}" + f"{{Prediction dataset: available=no, size=unknown}}{os.linesep}" + ) + output = str(dm) + + assert output == expected_output From 21029d2d4ab767e16fe24c7f61d0c607c1849618 Mon Sep 17 00:00:00 2001 From: Mr WhatZitTooYa Date: Thu, 5 Dec 2024 12:37:00 -0500 Subject: [PATCH 13/17] Implementing str method for datamodule Refactored code and made it more readable by implementing more abstract function methods Adjusted tests Removed debug statements Removed TODO comments --- src/lightning/pytorch/core/datamodule.py | 78 ++++++++++++-------- tests/tests_pytorch/core/test_datamodules.py | 22 ++---- 2 files changed, 54 insertions(+), 46 deletions(-) diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index 645c1c7551ad3..8fa73a1baa6f6 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -254,52 +254,68 @@ def __str__(self) -> str: """ - def dataset_info(loader: DataLoader) -> tuple[str, str]: + class dataset_info: + def __init__(self, available: str, length: str) -> None: + self.available = available + self.length = length + + def retrieve_dataset_info(loader: DataLoader) -> dataset_info: """Helper function to compute dataset information.""" dataset = loader.dataset size: str = str(len(dataset)) if isinstance(dataset, Sized) else "unknown" + output = dataset_info("yes", size) + return output - return "yes", size - - def loader_info(loader_instance: Union[DataLoader, Iterable[DataLoader]]) -> str: + def loader_info( + loader_instance: Union[DataLoader, Iterable[DataLoader]], + ) -> Union[dataset_info, Iterable[dataset_info]]: """Helper function to compute dataset information.""" - result = apply_to_collection(loader_instance, DataLoader, dataset_info) + result = apply_to_collection(loader_instance, DataLoader, retrieve_dataset_info) return result + def
extract_loader_info(methods: list[tuple[str, str]]) -> dict: + """Helper function to extract information for each dataloader method.""" + info: dict[str, Union[dataset_info, Iterable[dataset_info]]] = {} + for method_str, function_name in methods: + loader_method = getattr(self, function_name, None) + + try: + loader_instance = loader_method() + info[method_str] = loader_info(loader_instance) + except Exception: + info[method_str] = dataset_info("no", "unknown") + + return info + + def format_loader_info(info: dict[str, Union[dataset_info, Iterable[dataset_info]]]) -> str: + """Helper function to format loader information.""" + lines = [] + for method_str, method_info in info.items(): + # Single dataset + if isinstance(method_info, dataset_info): + data_info = f"{{{method_str}: available={method_info.available}, size={method_info.length}}}" + lines.append(data_info) + # Iterable of datasets + else: + itr_data_info = " ; ".join( + f"{i}. available={dataset.available}, size={dataset.length}" + for i, dataset in enumerate(method_info, start=1) + ) + lines.append(f"{{{method_str}: {itr_data_info}}}") + + return os.linesep.join(lines) + + # Available dataloader methods dataloader_methods: list[tuple[str, str]] = [ ("Train dataset", "train_dataloader"), ("Validation dataset", "val_dataloader"), ("Test dataset", "test_dataloader"), ("Prediction dataset", "predict_dataloader"), ] - dataloader_info: dict[str, Union[tuple[str, str], Iterable[tuple[str, str]]]] = {} # Retrieve information for each dataloader method - for method_pair in dataloader_methods: - method_str, method_name = method_pair - loader_method = getattr(self, method_name, None) - - try: - loader_instance = loader_method() - dataloader_info[method_str] = loader_info(loader_instance) - except Exception: - dataloader_info[method_str] = ("no", "unknown") - + dataloader_info = extract_loader_info(dataloader_methods) # Format the information - dataloader_str: str = "" - for method_str, method_info in 
dataloader_info.items(): - # Single data set - if isinstance(method_info, tuple): - dataloader_str += f"{{{method_str}: " - dataloader_str += f"available={method_info[0]}, size={method_info[1]}" - dataloader_str += f"}}{os.linesep}" - else: - # Iterable of datasets - dataloader_str += f"{{{method_str}: " - for i, info in enumerate(method_info, start=1): - dataloader_str += f"{i}. available={info[0]}, size={info[1]} ; " - dataloader_str = dataloader_str[:-3] - dataloader_str += f"}}{os.linesep}" - + dataloader_str = format_loader_info(dataloader_info) return dataloader_str diff --git a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py index d02b09b2b5268..c1fc24e0d9df3 100644 --- a/tests/tests_pytorch/core/test_datamodules.py +++ b/tests/tests_pytorch/core/test_datamodules.py @@ -518,7 +518,6 @@ def prepare_data(self): assert durations[0] > 0 -# TODO: Remove last os.linesep def test_datamodule_string_not_available(): dm = BoringDataModule() @@ -526,14 +525,13 @@ def test_datamodule_string_not_available(): f"{{Train dataset: available=no, size=unknown}}{os.linesep}" f"{{Validation dataset: available=no, size=unknown}}{os.linesep}" f"{{Test dataset: available=no, size=unknown}}{os.linesep}" - f"{{Prediction dataset: available=no, size=unknown}}{os.linesep}" + f"{{Prediction dataset: available=no, size=unknown}}" ) out = str(dm) assert out == expected_output -# TODO Remove prints def test_datamodule_string_fit_setup(): dm = BoringDataModule() dm.setup(stage="fit") @@ -542,16 +540,10 @@ def test_datamodule_string_fit_setup(): f"{{Train dataset: available=yes, size=64}}{os.linesep}" f"{{Validation dataset: available=yes, size=64}}{os.linesep}" f"{{Test dataset: available=no, size=unknown}}{os.linesep}" - f"{{Prediction dataset: available=no, size=unknown}}{os.linesep}" + f"{{Prediction dataset: available=no, size=unknown}}" ) output = str(dm) - print() - print(repr(expected_output)) - print() - print(repr(output)) - print() - 
assert expected_output == output @@ -563,7 +555,7 @@ def test_datamodule_string_validation_setup(): f"{{Train dataset: available=no, size=unknown}}{os.linesep}" f"{{Validation dataset: available=yes, size=64}}{os.linesep}" f"{{Test dataset: available=no, size=unknown}}{os.linesep}" - f"{{Prediction dataset: available=no, size=unknown}}{os.linesep}" + f"{{Prediction dataset: available=no, size=unknown}}" ) output = str(dm) @@ -578,7 +570,7 @@ def test_datamodule_string_test_setup(): f"{{Train dataset: available=no, size=unknown}}{os.linesep}" f"{{Validation dataset: available=no, size=unknown}}{os.linesep}" f"{{Test dataset: available=yes, size=64}}{os.linesep}" - f"{{Prediction dataset: available=no, size=unknown}}{os.linesep}" + f"{{Prediction dataset: available=no, size=unknown}}" ) output = str(dm) @@ -593,7 +585,7 @@ def test_datamodule_string_predict_setup(): f"{{Train dataset: available=no, size=unknown}}{os.linesep}" f"{{Validation dataset: available=no, size=unknown}}{os.linesep}" f"{{Test dataset: available=no, size=unknown}}{os.linesep}" - f"{{Prediction dataset: available=yes, size=64}}{os.linesep}" + f"{{Prediction dataset: available=yes, size=64}}" ) output = str(dm) @@ -608,7 +600,7 @@ def test_datamodule_string_no_len(): f"{{Train dataset: available=yes, size=unknown}}{os.linesep}" f"{{Validation dataset: available=yes, size=unknown}}{os.linesep}" f"{{Test dataset: available=no, size=unknown}}{os.linesep}" - f"{{Prediction dataset: available=no, size=unknown}}{os.linesep}" + f"{{Prediction dataset: available=no, size=unknown}}" ) output = str(dm) @@ -623,7 +615,7 @@ def test_datamodule_string_iterable(): f"{{Train dataset: 1. available=yes, size=16 ; 2. available=yes, size=unknown}}{os.linesep}" f"{{Validation dataset: 1. available=yes, size=32 ; 2. 
available=yes, size=unknown}}{os.linesep}" f"{{Test dataset: available=no, size=unknown}}{os.linesep}" - f"{{Prediction dataset: available=no, size=unknown}}{os.linesep}" + f"{{Prediction dataset: available=no, size=unknown}}" ) output = str(dm) From 330a88cb35d922462348fde314766ac68496e6c3 Mon Sep 17 00:00:00 2001 From: Mr WhatZitTooYa Date: Thu, 5 Dec 2024 12:53:23 -0500 Subject: [PATCH 14/17] Finalized required adjustments for dataloader string proposal method Renamed variables to more sensible names to increase readability --- src/lightning/pytorch/core/datamodule.py | 44 +++++++++++------------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index 8fa73a1baa6f6..2c1658e7b75d3 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -263,51 +263,49 @@ def retrieve_dataset_info(loader: DataLoader) -> dataset_info: """Helper function to compute dataset information.""" dataset = loader.dataset size: str = str(len(dataset)) if isinstance(dataset, Sized) else "unknown" - output = dataset_info("yes", size) - return output + + return dataset_info("yes", size) def loader_info( - loader_instance: Union[DataLoader, Iterable[DataLoader]], + loader: Union[DataLoader, Iterable[DataLoader]], ) -> Union[dataset_info, Iterable[dataset_info]]: """Helper function to compute dataset information.""" - result = apply_to_collection(loader, DataLoader, retrieve_dataset_info) - - return result + return apply_to_collection(loader, DataLoader, retrieve_dataset_info) def extract_loader_info(methods: list[tuple[str, str]]) -> dict: """Helper function to extract information for each dataloader method.""" info: dict[str, Union[dataset_info, Iterable[dataset_info]]] = {} - for method_str, function_name in methods: - loader_method = getattr(self, function_name, None) + for loader_name, func_name in methods: + loader_callback =
getattr(self, func_name, None) try: - loader_instance = loader_method() - info[method_str] = loader_info(loader_instance) + loader = loader_callback() + info[loader_name] = loader_info(loader) except Exception: - info[method_str] = dataset_info("no", "unknown") + info[loader_name] = dataset_info("no", "unknown") return info def format_loader_info(info: dict[str, Union[dataset_info, Iterable[dataset_info]]]) -> str: """Helper function to format loader information.""" - lines = [] - for method_str, method_info in info.items(): + output = [] + for loader_name, loader_info in info.items(): # Single dataset - if isinstance(method_info, dataset_info): - data_info = f"{{{method_str}: available={method_info.available}, size={method_info.length}}}" - lines.append(data_info) + if isinstance(loader_info, dataset_info): + loader_info_formatted = f"available={loader_info.available}, size={loader_info.length}" # Iterable of datasets else: - itr_data_info = " ; ".join( - f"{i}. available={dataset.available}, size={dataset.length}" - for i, dataset in enumerate(method_info, start=1) + loader_info_formatted = " ; ".join( + f"{i}. 
available={loader_info_i.available}, size={loader_info_i.length}" + for i, loader_info_i in enumerate(loader_info, start=1) ) - lines.append(f"{{{method_str}: {itr_data_info}}}") - return os.linesep.join(lines) + output.append(f"{{{loader_name}: {loader_info_formatted}}}") + + return os.linesep.join(output) # Available dataloader methods - dataloader_methods: list[tuple[str, str]] = [ + datamodule_loader_methods: list[tuple[str, str]] = [ ("Train dataset", "train_dataloader"), ("Validation dataset", "val_dataloader"), ("Test dataset", "test_dataloader"), @@ -315,7 +313,7 @@ def format_loader_info(info: dict[str, Union[dataset_info, Iterable[dataset_info ] # Retrieve information for each dataloader method - dataloader_info = extract_loader_info(dataloader_methods) + dataloader_info = extract_loader_info(datamodule_loader_methods) # Format the information dataloader_str = format_loader_info(dataloader_info) return dataloader_str From 06e4f1c371f315103853d2b7c0a51a4a76b85f0c Mon Sep 17 00:00:00 2001 From: Mr WhatZitTooYa Date: Tue, 10 Dec 2024 16:15:03 -0500 Subject: [PATCH 15/17] Implementing str method Switched name from dataset to dataloader Switched name Prediction to Predict removed available keyword and instead write None if not available Switched from unknown to NA --- src/lightning/pytorch/core/datamodule.py | 20 +++---- tests/tests_pytorch/core/test_datamodules.py | 56 ++++++++++---------- 2 files changed, 38 insertions(+), 38 deletions(-) diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index 2c1658e7b75d3..018e9fb16a740 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -255,16 +255,16 @@ def __str__(self) -> str: """ class dataset_info: - def __init__(self, available: str, length: str) -> None: + def __init__(self, available: bool, length: str) -> None: self.available = available self.length = length def retrieve_dataset_info(loader: DataLoader) -> 
dataset_info: """Helper function to compute dataset information.""" dataset = loader.dataset - size: str = str(len(dataset)) if isinstance(dataset, Sized) else "unknown" + size: str = str(len(dataset)) if isinstance(dataset, Sized) else "NA" - return dataset_info("yes", size) + return dataset_info(True, size) def loader_info( loader: Union[DataLoader, Iterable[DataLoader]], @@ -282,7 +282,7 @@ def extract_loader_info(methods: list[tuple[str, str]]) -> dict: loader = loader_callback() info[loader_name] = loader_info(loader) except Exception: - info[loader_name] = dataset_info("no", "unknown") + info[loader_name] = dataset_info(False, "") return info @@ -292,11 +292,11 @@ def format_loader_info(info: dict[str, Union[dataset_info, Iterable[dataset_info for loader_name, loader_info in info.items(): # Single dataset if isinstance(loader_info, dataset_info): - loader_info_formatted = f"available={loader_info.available}, size={loader_info.length}" + loader_info_formatted = "None" if not loader_info.available else f"size={loader_info.length}" # Iterable of datasets else: loader_info_formatted = " ; ".join( - f"{i}. available={loader_info_i.available}, size={loader_info_i.length}" + "None" if not loader_info_i.available else f"{i}. 
size={loader_info_i.length}" for i, loader_info_i in enumerate(loader_info, start=1) ) @@ -306,10 +306,10 @@ def format_loader_info(info: dict[str, Union[dataset_info, Iterable[dataset_info # Available dataloader methods datamodule_loader_methods: list[tuple[str, str]] = [ - ("Train dataset", "train_dataloader"), - ("Validation dataset", "val_dataloader"), - ("Test dataset", "test_dataloader"), - ("Prediction dataset", "predict_dataloader"), + ("Train dataloader", "train_dataloader"), + ("Validation dataloader", "val_dataloader"), + ("Test dataloader", "test_dataloader"), + ("Predict dataloader", "predict_dataloader"), ] # Retrieve information for each dataloader method diff --git a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py index c1fc24e0d9df3..b3ccd88aae704 100644 --- a/tests/tests_pytorch/core/test_datamodules.py +++ b/tests/tests_pytorch/core/test_datamodules.py @@ -522,10 +522,10 @@ def test_datamodule_string_not_available(): dm = BoringDataModule() expected_output = ( - f"{{Train dataset: available=no, size=unknown}}{os.linesep}" - f"{{Validation dataset: available=no, size=unknown}}{os.linesep}" - f"{{Test dataset: available=no, size=unknown}}{os.linesep}" - f"{{Prediction dataset: available=no, size=unknown}}" + f"{{Train dataloader: None}}{os.linesep}" + f"{{Validation dataloader: None}}{os.linesep}" + f"{{Test dataloader: None}}{os.linesep}" + f"{{Predict dataloader: None}}" ) out = str(dm) @@ -537,10 +537,10 @@ def test_datamodule_string_fit_setup(): dm.setup(stage="fit") expected_output = ( - f"{{Train dataset: available=yes, size=64}}{os.linesep}" - f"{{Validation dataset: available=yes, size=64}}{os.linesep}" - f"{{Test dataset: available=no, size=unknown}}{os.linesep}" - f"{{Prediction dataset: available=no, size=unknown}}" + f"{{Train dataloader: size=64}}{os.linesep}" + f"{{Validation dataloader: size=64}}{os.linesep}" + f"{{Test dataloader: None}}{os.linesep}" + f"{{Predict dataloader: None}}" ) 
output = str(dm) @@ -552,10 +552,10 @@ def test_datamodule_string_validation_setup(): dm.setup(stage="validate") expected_output = ( - f"{{Train dataset: available=no, size=unknown}}{os.linesep}" - f"{{Validation dataset: available=yes, size=64}}{os.linesep}" - f"{{Test dataset: available=no, size=unknown}}{os.linesep}" - f"{{Prediction dataset: available=no, size=unknown}}" + f"{{Train dataloader: None}}{os.linesep}" + f"{{Validation dataloader: size=64}}{os.linesep}" + f"{{Test dataloader: None}}{os.linesep}" + f"{{Predict dataloader: None}}" ) output = str(dm) @@ -567,10 +567,10 @@ def test_datamodule_string_test_setup(): dm.setup(stage="test") expected_output = ( - f"{{Train dataset: available=no, size=unknown}}{os.linesep}" - f"{{Validation dataset: available=no, size=unknown}}{os.linesep}" - f"{{Test dataset: available=yes, size=64}}{os.linesep}" - f"{{Prediction dataset: available=no, size=unknown}}" + f"{{Train dataloader: None}}{os.linesep}" + f"{{Validation dataloader: None}}{os.linesep}" + f"{{Test dataloader: size=64}}{os.linesep}" + f"{{Predict dataloader: None}}" ) output = str(dm) @@ -582,10 +582,10 @@ def test_datamodule_string_predict_setup(): dm.setup(stage="predict") expected_output = ( - f"{{Train dataset: available=no, size=unknown}}{os.linesep}" - f"{{Validation dataset: available=no, size=unknown}}{os.linesep}" - f"{{Test dataset: available=no, size=unknown}}{os.linesep}" - f"{{Prediction dataset: available=yes, size=64}}" + f"{{Train dataloader: None}}{os.linesep}" + f"{{Validation dataloader: None}}{os.linesep}" + f"{{Test dataloader: None}}{os.linesep}" + f"{{Predict dataloader: size=64}}" ) output = str(dm) @@ -597,10 +597,10 @@ def test_datamodule_string_no_len(): dm.setup("fit") expected_output = ( - f"{{Train dataset: available=yes, size=unknown}}{os.linesep}" - f"{{Validation dataset: available=yes, size=unknown}}{os.linesep}" - f"{{Test dataset: available=no, size=unknown}}{os.linesep}" - f"{{Prediction dataset: available=no, 
size=unknown}}" + f"{{Train dataloader: size=NA}}{os.linesep}" + f"{{Validation dataloader: size=NA}}{os.linesep}" + f"{{Test dataloader: None}}{os.linesep}" + f"{{Predict dataloader: None}}" ) output = str(dm) @@ -612,10 +612,10 @@ def test_datamodule_string_iterable(): dm.setup("fit") expected_output = ( - f"{{Train dataset: 1. available=yes, size=16 ; 2. available=yes, size=unknown}}{os.linesep}" - f"{{Validation dataset: 1. available=yes, size=32 ; 2. available=yes, size=unknown}}{os.linesep}" - f"{{Test dataset: available=no, size=unknown}}{os.linesep}" - f"{{Prediction dataset: available=no, size=unknown}}" + f"{{Train dataloader: 1. size=16 ; 2. size=NA}}{os.linesep}" + f"{{Validation dataloader: 1. size=32 ; 2. size=NA}}{os.linesep}" + f"{{Test dataloader: None}}{os.linesep}" + f"{{Predict dataloader: None}}" ) output = str(dm) From 33cda778496c560b672340d121bb72313c7a43e9 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Wed, 11 Dec 2024 00:24:45 +0100 Subject: [PATCH 16/17] Update src/lightning/pytorch/core/datamodule.py --- src/lightning/pytorch/core/datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index 018e9fb16a740..0b4eea968510a 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -279,7 +279,7 @@ def extract_loader_info(methods: list[tuple[str, str]]) -> dict: loader_callback = getattr(self, func_name, None) try: - loader = loader_callback() + loader = loader_method() # type: ignore info[loader_name] = loader_info(loader) except Exception: info[loader_name] = dataset_info(False, "") From f492167c70993ac6e8b6272036780e3f93759805 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Wed, 11 Dec 2024 00:24:51 +0100 Subject: [PATCH 17/17] Update src/lightning/pytorch/core/datamodule.py --- src/lightning/pytorch/core/datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index 0b4eea968510a..ff84c2fd8b199 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -276,7 +276,7 @@ def extract_loader_info(methods: list[tuple[str, str]]) -> dict: """Helper function to extract information for each dataloader method.""" info: dict[str, Union[dataset_info, Iterable[dataset_info]]] = {} for loader_name, func_name in methods: - loader_callback = getattr(self, func_name, None) + loader_method = getattr(self, func_name, None) try: loader = loader_method() # type: ignore