Add str method to datamodule #20301

Merged
23 commits
a011366
Add feature implementation to datamodule for str method
MrWhatZitToYaa Sep 25, 2024
137e7b5
Removed list / tuple case for datamodule str method
MrWhatZitToYaa Sep 25, 2024
efe0c3c
Added test cases for DataModule string function
MrWhatZitToYaa Sep 25, 2024
23326d7
Reverted accidental changes in DataModule
MrWhatZitToYaa Sep 25, 2024
122cf6d
Updated dataloader str method
MrWhatZitToYaa Nov 20, 2024
9acf680
Merge branch 'master' into feature/9947_dataloader-string
lantiga Nov 20, 2024
51a4901
Improvements to implementation of str method for datamodule
MrWhatZitToYaa Nov 21, 2024
d37dfb0
Merge branch 'master' into feature/9947_dataloader-string
lantiga Nov 25, 2024
1ce0f92
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 25, 2024
20a2a0c
Implementing str method for datamodule
MrWhatZitToYaa Nov 25, 2024
e03aefb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 25, 2024
6d05cfe
Add string method to datamodule
MrWhatZitToYaa Nov 25, 2024
acd2f73
Implementing str method for dataloader
MrWhatZitToYaa Dec 4, 2024
b08e1fe
Implementing str function for datamodule
MrWhatZitToYaa Dec 4, 2024
21029d2
Implementing str method for datamodule
MrWhatZitToYaa Dec 5, 2024
330a88c
Finalized required adjustments for dataloader string proposal method
MrWhatZitToYaa Dec 5, 2024
dbaabaa
Merge branch 'master' into feature/9947_dataloader-string
MrWhatZitToYaa Dec 6, 2024
5bf252e
Merge branch 'master' into feature/9947_dataloader-string
MrWhatZitToYaa Dec 10, 2024
06e4f1c
Implementing str method
MrWhatZitToYaa Dec 10, 2024
63200a0
Merge branch 'master' into feature/9947_dataloader-string
lantiga Dec 10, 2024
33cda77
Update src/lightning/pytorch/core/datamodule.py
lantiga Dec 10, 2024
f492167
Update src/lightning/pytorch/core/datamodule.py
lantiga Dec 10, 2024
36fdb57
Merge branch 'master' into feature/9947_dataloader-string
lantiga Dec 10, 2024
75 changes: 74 additions & 1 deletion src/lightning/pytorch/core/datamodule.py
@@ -14,7 +14,8 @@
"""LightningDataModule for loading DataLoaders with ease."""

import inspect
from collections.abc import Iterable
import os
from collections.abc import Iterable, Sized
from typing import IO, Any, Optional, Union, cast

from lightning_utilities import apply_to_collection
@@ -244,3 +245,75 @@ def load_from_checkpoint(
**kwargs,
)
return cast(Self, loaded)

def __str__(self) -> str:
"""Return a string representation of the datasets that are set up.

Returns:
A string representation of the datasets that are set up.

"""

class dataset_info:
def __init__(self, available: bool, length: str) -> None:
self.available = available
self.length = length

def retrieve_dataset_info(loader: DataLoader) -> dataset_info:
"""Helper function to compute dataset information."""
dataset = loader.dataset
size: str = str(len(dataset)) if isinstance(dataset, Sized) else "NA"

return dataset_info(True, size)

def loader_info(
loader: Union[DataLoader, Iterable[DataLoader]],
) -> Union[dataset_info, Iterable[dataset_info]]:
"""Helper function to compute dataset information."""
return apply_to_collection(loader, DataLoader, retrieve_dataset_info)

def extract_loader_info(methods: list[tuple[str, str]]) -> dict:
"""Helper function to extract information for each dataloader method."""
info: dict[str, Union[dataset_info, Iterable[dataset_info]]] = {}
for loader_name, func_name in methods:
loader_callback = getattr(self, func_name, None)

try:
loader = loader_callback()
info[loader_name] = loader_info(loader)
except Exception:
info[loader_name] = dataset_info(False, "")

return info

def format_loader_info(info: dict[str, Union[dataset_info, Iterable[dataset_info]]]) -> str:
"""Helper function to format loader information."""
output = []
for loader_name, loader_info in info.items():
# Single dataset
if isinstance(loader_info, dataset_info):
loader_info_formatted = "None" if not loader_info.available else f"size={loader_info.length}"
# Iterable of datasets
else:
loader_info_formatted = " ; ".join(
"None" if not loader_info_i.available else f"{i}. size={loader_info_i.length}"
for i, loader_info_i in enumerate(loader_info, start=1)
)

output.append(f"{{{loader_name}: {loader_info_formatted}}}")

return os.linesep.join(output)

# Available dataloader methods
datamodule_loader_methods: list[tuple[str, str]] = [
("Train dataloader", "train_dataloader"),
("Validation dataloader", "val_dataloader"),
("Test dataloader", "test_dataloader"),
("Predict dataloader", "predict_dataloader"),
]

# Retrieve information for each dataloader method
dataloader_info = extract_loader_info(datamodule_loader_methods)
# Format the information
dataloader_str = format_loader_info(dataloader_info)
return dataloader_str
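
For reference, a minimal usage sketch of the new method (a sketch only: the sizes shown mirror the BoringDataModule tests added in this PR, and the output lines are joined with os.linesep):

from lightning.pytorch.demos.boring_classes import BoringDataModule

dm = BoringDataModule()
dm.setup(stage="fit")
print(dm)
# Expected output per the tests below; loaders whose setup stage has not run print as None:
# {Train dataloader: size=64}
# {Validation dataloader: size=64}
# {Test dataloader: None}
# {Predict dataloader: None}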
83 changes: 82 additions & 1 deletion src/lightning/pytorch/demos/boring_classes.py
@@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterator
from collections.abc import Iterable, Iterator
from typing import Any, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning_utilities import apply_to_collection
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
@@ -188,6 +189,86 @@ def predict_dataloader(self) -> DataLoader:
return DataLoader(self.random_predict)


class BoringDataModuleNoLen(LightningDataModule):
"""
.. warning:: This is meant for testing/debugging and is experimental.
"""

def __init__(self) -> None:
super().__init__()

def setup(self, stage: str) -> None:
if stage == "fit":
self.random_train = RandomIterableDataset(32, 512)

if stage in ("fit", "validate"):
self.random_val = RandomIterableDataset(32, 128)

if stage == "test":
self.random_test = RandomIterableDataset(32, 256)

if stage == "predict":
self.random_predict = RandomIterableDataset(32, 64)

def train_dataloader(self) -> DataLoader:
return DataLoader(self.random_train)

def val_dataloader(self) -> DataLoader:
return DataLoader(self.random_val)

def test_dataloader(self) -> DataLoader:
return DataLoader(self.random_test)

def predict_dataloader(self) -> DataLoader:
return DataLoader(self.random_predict)


class IterableBoringDataModule(LightningDataModule):
def __init__(self) -> None:
super().__init__()

def setup(self, stage: str) -> None:
if stage == "fit":
self.train_datasets = [
RandomDataset(4, 16),
RandomIterableDataset(4, 16),
]

if stage in ("fit", "validate"):
self.val_datasets = [
RandomDataset(4, 32),
RandomIterableDataset(4, 32),
]

if stage == "test":
self.test_datasets = [
RandomDataset(4, 64),
RandomIterableDataset(4, 64),
]

if stage == "predict":
self.predict_datasets = [
RandomDataset(4, 128),
RandomIterableDataset(4, 128),
]

def train_dataloader(self) -> Iterable[DataLoader]:
combined_train = apply_to_collection(self.train_datasets, Dataset, lambda x: DataLoader(x))
return combined_train

def val_dataloader(self) -> Iterable[DataLoader]:
combined_val = apply_to_collection(self.val_datasets, Dataset, lambda x: DataLoader(x))
return combined_val

def test_dataloader(self) -> Iterable[DataLoader]:
combined_test = apply_to_collection(self.test_datasets, Dataset, lambda x: DataLoader(x))
return combined_test

def predict_dataloader(self) -> Iterable[DataLoader]:
combined_predict = apply_to_collection(self.predict_datasets, Dataset, lambda x: DataLoader(x))
return combined_predict


class ManualOptimBoringModel(BoringModel):
"""
.. warning:: This is meant for testing/debugging and is experimental.
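A similar sketch for the IterableBoringDataModule demo added above (values mirror the iterable test added below; size=NA appears because RandomIterableDataset does not implement __len__, so the Sized check in __str__ falls back to "NA"):

from lightning.pytorch.demos.boring_classes import IterableBoringDataModule

dm = IterableBoringDataModule()
dm.setup("fit")
print(dm)
# Expected output per the tests below:
# {Train dataloader: 1. size=16 ; 2. size=NA}
# {Validation dataloader: 1. size=32 ; 2. size=NA}
# {Test dataloader: None}
# {Predict dataloader: None}
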
112 changes: 111 additions & 1 deletion tests/tests_pytorch/core/test_datamodules.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pickle
from argparse import Namespace
from dataclasses import dataclass
@@ -22,7 +23,12 @@
import torch
from lightning.pytorch import LightningDataModule, Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
from lightning.pytorch.demos.boring_classes import (
BoringDataModule,
BoringDataModuleNoLen,
BoringModel,
IterableBoringDataModule,
)
from lightning.pytorch.profilers.simple import SimpleProfiler
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities import AttributeDict
@@ -510,3 +516,107 @@ def prepare_data(self):
durations = profiler.recorded_durations[key]
assert len(durations) == 1
assert durations[0] > 0


def test_datamodule_string_not_available():
dm = BoringDataModule()

expected_output = (
f"{{Train dataloader: None}}{os.linesep}"
f"{{Validation dataloader: None}}{os.linesep}"
f"{{Test dataloader: None}}{os.linesep}"
f"{{Predict dataloader: None}}"
)
out = str(dm)

assert out == expected_output


def test_datamodule_string_fit_setup():
dm = BoringDataModule()
dm.setup(stage="fit")

expected_output = (
f"{{Train dataloader: size=64}}{os.linesep}"
f"{{Validation dataloader: size=64}}{os.linesep}"
f"{{Test dataloader: None}}{os.linesep}"
f"{{Predict dataloader: None}}"
)
output = str(dm)

assert expected_output == output


def test_datamodule_string_validation_setup():
dm = BoringDataModule()
dm.setup(stage="validate")

expected_output = (
f"{{Train dataloader: None}}{os.linesep}"
f"{{Validation dataloader: size=64}}{os.linesep}"
f"{{Test dataloader: None}}{os.linesep}"
f"{{Predict dataloader: None}}"
)
output = str(dm)

assert expected_output == output


def test_datamodule_string_test_setup():
dm = BoringDataModule()
dm.setup(stage="test")

expected_output = (
f"{{Train dataloader: None}}{os.linesep}"
f"{{Validation dataloader: None}}{os.linesep}"
f"{{Test dataloader: size=64}}{os.linesep}"
f"{{Predict dataloader: None}}"
)
output = str(dm)

assert expected_output == output


def test_datamodule_string_predict_setup():
dm = BoringDataModule()
dm.setup(stage="predict")

expected_output = (
f"{{Train dataloader: None}}{os.linesep}"
f"{{Validation dataloader: None}}{os.linesep}"
f"{{Test dataloader: None}}{os.linesep}"
f"{{Predict dataloader: size=64}}"
)
output = str(dm)

assert expected_output == output


def test_datamodule_string_no_len():
dm = BoringDataModuleNoLen()
dm.setup("fit")

expected_output = (
f"{{Train dataloader: size=NA}}{os.linesep}"
f"{{Validation dataloader: size=NA}}{os.linesep}"
f"{{Test dataloader: None}}{os.linesep}"
f"{{Predict dataloader: None}}"
)
output = str(dm)

assert output == expected_output


def test_datamodule_string_iterable():
dm = IterableBoringDataModule()
dm.setup("fit")

expected_output = (
f"{{Train dataloader: 1. size=16 ; 2. size=NA}}{os.linesep}"
f"{{Validation dataloader: 1. size=32 ; 2. size=NA}}{os.linesep}"
f"{{Test dataloader: None}}{os.linesep}"
f"{{Predict dataloader: None}}"
)
output = str(dm)

assert output == expected_output