diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 9fc952f1..4e60d89a 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -60,7 +60,8 @@ }, "containerEnv": { "SCRATCH": "/home/vscode/scratch", - "SLURM_TMPDIR": "/tmp" + "SLURM_TMPDIR": "/tmp", + "NETWORK_DIR": "/network" }, "mounts": [ // https://code.visualstudio.com/remote/advancedcontainers/add-local-file-mount diff --git a/project/__init__.py b/project/__init__.py index 99210e55..b71fe767 100644 --- a/project/__init__.py +++ b/project/__init__.py @@ -1,6 +1,7 @@ from .algorithms import Algorithm, ExampleAlgorithm, ManualGradientsExample, NoOp from .configs import Config -from .datamodules import ImageClassificationDataModule, VisionDataModule +from .datamodules import VisionDataModule +from .datamodules.image_classification.image_classification import ImageClassificationDataModule from .experiment import Experiment from .networks import FcNet diff --git a/project/algorithms/__init__.py b/project/algorithms/__init__.py index 5a0a3021..cd52e402 100644 --- a/project/algorithms/__init__.py +++ b/project/algorithms/__init__.py @@ -3,8 +3,7 @@ from project.algorithms.jax_algo import JaxAlgorithm from project.algorithms.no_op import NoOp -from .bases.algorithm import Algorithm -from .bases.image_classification import ImageClassificationAlgorithm +from .algorithm import Algorithm from .example_algo import ExampleAlgorithm from .manual_optimization_example import ManualGradientsExample @@ -26,6 +25,6 @@ __all__ = [ "Algorithm", "ExampleAlgorithm", - "ImageClassificationAlgorithm", "ManualGradientsExample", + "JaxAlgorithm", ] diff --git a/project/algorithms/algorithm.py b/project/algorithms/algorithm.py new file mode 100644 index 00000000..71e98dbc --- /dev/null +++ b/project/algorithms/algorithm.py @@ -0,0 +1,164 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import NotRequired, TypedDict + +import torch +from 
lightning import Callback, LightningModule, Trainer +from torch import Tensor +from typing_extensions import Generic, TypeVar # noqa + +from project.datamodules.image_classification.image_classification import ( + ImageClassificationDataModule, +) +from project.utils.types import PhaseStr, PyTree +from project.utils.types.protocols import DataModule, Module + + +class StepOutputDict(TypedDict, total=False): + """A dictionary that shows what an Algorithm can output from + `training/validation/test_step`.""" + + loss: NotRequired[Tensor | float] + """Optional loss tensor that can be returned by those methods.""" + + +BatchType = TypeVar("BatchType", bound=PyTree[torch.Tensor], contravariant=True) +# StepOutputType = TypeVar( +# "StepOutputType", bound=StepOutputDict | PyTree[torch.Tensor], covariant=True +# ) +StepOutputType = TypeVar( + "StepOutputType", + bound=torch.Tensor | StepOutputDict, + default=StepOutputDict, + covariant=True, +) + + +class Algorithm(LightningModule, ABC, Generic[BatchType, StepOutputType]): + """Base class for a learning algorithm. + + This is an extension of the LightningModule class from PyTorch Lightning, with some common + boilerplate code to keep the algorithm implementations as simple as possible. + + The networks themselves are created separately and passed as a constructor argument. This is + meant to make it easier to compare different learning algorithms on the same network + architecture. + """ + + @dataclass + class HParams: + """Hyper-parameters of the algorithm.""" + + def __init__( + self, + *, + datamodule: DataModule[BatchType] | None = None, + network: Module | None = None, + hp: HParams | None = None, + ): + super().__init__() + self.datamodule = datamodule + self.network = network + self.hp = hp or self.HParams() + # fix for `self.device` property which defaults to cpu. 
+ self._device = None + + if isinstance(datamodule, ImageClassificationDataModule): + self.example_input_array = torch.zeros( + (datamodule.batch_size, *datamodule.dims), device=self.device + ) + + self.trainer: Trainer + + def training_step(self, batch: BatchType, batch_index: int) -> StepOutputType: + """Performs a training step.""" + return self.shared_step(batch=batch, batch_index=batch_index, phase="train") + + def validation_step(self, batch: BatchType, batch_index: int) -> StepOutputType: + """Performs a validation step.""" + return self.shared_step(batch=batch, batch_index=batch_index, phase="val") + + def test_step(self, batch: BatchType, batch_index: int) -> StepOutputType: + """Performs a test step.""" + return self.shared_step(batch=batch, batch_index=batch_index, phase="test") + + def shared_step(self, batch: BatchType, batch_index: int, phase: PhaseStr) -> StepOutputType: + """Performs a training/validation/test step. + + This must return a nested dictionary of tensors matching the `StepOutputType` typedict for + this algorithm. By default, this is a `StepOutputDict`: a dictionary with an optional + `loss` entry. This is so that the training of the model is easier to parallelize the + training across GPUs: + - the cross entropy loss gets calculated using the global batch size + - the main metrics are logged inside `training_step_end` (supposed to be better for DP/DDP) + """ + raise NotImplementedError + + @abstractmethod + def configure_optimizers(self): + # """Creates the optimizers and the learning rate schedulers."""' + # super().configure_optimizers() + ... + + def forward(self, x: Tensor) -> Tensor: + """Performs a forward pass. + + Feel free to overwrite this to do whatever you'd like. + """ + assert self.network is not None + return self.network(x) + + def training_step_end(self, step_output: StepOutputDict) -> StepOutputDict: + """Called with the results of each worker / replica's output. + + See the `training_step_end` of pytorch-lightning for more info. 
+ """ + return self.shared_step_end(step_output, phase="train") + + def validation_step_end[Out: torch.Tensor | StepOutputDict](self, step_output: Out) -> Out: + return self.shared_step_end(step_output, phase="val") + + def test_step_end[Out: torch.Tensor | StepOutputDict](self, step_output: Out) -> Out: + return self.shared_step_end(step_output, phase="test") + + def shared_step_end[Out: torch.Tensor | StepOutputDict]( + self, step_output: Out, phase: PhaseStr + ) -> Out: + """This is a default implementation for `[train/validation/test]_step_end`. + + This does the following: + - Averages out the `loss` tensor if it was left unreduced. + - the main metrics are logged inside `training_step_end` (supposed to be better for DP/DDP) + """ + + if ( + isinstance(step_output, dict) + and isinstance((loss := step_output.get("loss")), torch.Tensor) + and loss.shape + ): + # Replace the loss with its mean. This is useful when automatic + # optimization is enabled, for example in the example algo, where each replica + # returns the un-reduced cross-entropy loss. Here we need to reduce it to a scalar. + # todo: find out if this was already logged, to not log it twice. 
+ # self.log(f"{phase}/loss", loss.mean(), sync_dist=True) + return step_output | {"loss": loss.mean()} + + elif isinstance(step_output, torch.Tensor) and (loss := step_output).shape: + return loss.mean() + + # self.log(f"{phase}/loss", torch.as_tensor(loss).mean(), sync_dist=True) + return step_output + + def configure_callbacks(self) -> list[Callback]: + """Use this to add some callbacks that should always be included with the model.""" + return [] + + @property + def device(self) -> torch.device: + if self._device is None: + self._device = next((p.device for p in self.parameters()), torch.device("cpu")) + device = self._device + # make this more explicit to always include the index + if device.type == "cuda" and device.index is None: + return torch.device("cuda", index=torch.cuda.current_device()) + return device diff --git a/project/algorithms/bases/algorithm_test.py b/project/algorithms/algorithm_test.py similarity index 95% rename from project/algorithms/bases/algorithm_test.py rename to project/algorithms/algorithm_test.py index 8d9a69ef..4983b026 100644 --- a/project/algorithms/bases/algorithm_test.py +++ b/project/algorithms/algorithm_test.py @@ -8,11 +8,11 @@ from collections.abc import Callable, Sequence from logging import getLogger as get_logger from pathlib import Path -from typing import Any, ClassVar, Generic, Literal, TypeVar +from typing import Any, ClassVar, Literal import pytest import torch -from lightning import Callback, LightningDataModule, LightningModule, Trainer +from lightning import LightningDataModule, LightningModule, Trainer from lightning.pytorch.utilities.types import STEP_OUTPUT from omegaconf import DictConfig from tensor_regression import TensorRegressionFixture @@ -20,9 +20,11 @@ from torch.utils.data import DataLoader from typing_extensions import ParamSpec +from project.algorithms.algorithm import Algorithm +from project.algorithms.callbacks.callback import Callback from project.configs import Config, cs from project.conftest 
import setup_hydra_for_tests_and_compose -from project.datamodules.image_classification import ( +from project.datamodules.image_classification.image_classification import ( ImageClassificationDataModule, ) from project.datamodules.vision import VisionDataModule @@ -42,21 +44,15 @@ ) from project.utils.types.protocols import DataModule -from .algorithm import Algorithm - logger = get_logger(__name__) P = ParamSpec("P") -AlgorithmType = TypeVar("AlgorithmType", bound=Algorithm) - SKIP_OR_XFAIL = pytest.xfail if "-vvv" in sys.argv else pytest.skip """Either skips the test entirely (default) or tries to run it and expect it to fail (slower).""" -skip_test = pytest.mark.xfail if "-vvv" in sys.argv else pytest.mark.skip - -class AlgorithmTests(Generic[AlgorithmType]): +class AlgorithmTests[AlgorithmType: Algorithm]: """Unit tests for an algorithm class. The algorithm creation is parametrized with all the datasets and all the networks, but the @@ -112,7 +108,7 @@ def n_updates(self) -> int: ```python @pytest.fixture - def n_updates(self, datamodule_name: str, network_name: str) -> int: + def n_updates(self, datamodule_name: str, network_name: str) -> int: if datamodule_name == "imagenet32": return 10 return 3 @@ -141,6 +137,7 @@ def test_loss_is_reproducible( ) def get_testing_callbacks(self) -> list[TestingCallback]: + """Callbacks to be used for unit tests.""" return [ AllParamsShouldHaveGradients(), ] @@ -207,7 +204,7 @@ def _train( accelerator=accelerator, default_root_dir=tmp_path, callbacks=testing_callbacks.copy(), # type: ignore - # NOTE: Would be nice to be able to enforce this, but DTP uses nn.MaxUnpool2d. + # NOTE: Would be nice to be able to enforce this in general, but some algos could be using nn.MaxUnpool2d. 
deterministic=True if can_use_deterministic_mode else "warn", **trainer_kwargs, ) @@ -293,7 +290,12 @@ def datamodule_name(self, request: pytest.FixtureRequest): if datamodule_name in default_marks_for_config_name: for marker in default_marks_for_config_name[datamodule_name]: request.applymarker(marker) - self._skip_if_unsupported("datamodule", datamodule_name, skip_or_xfail=SKIP_OR_XFAIL) + # todo: if _supported_datamodule_types contains a protocol, this will raise a TypeError. In + # this case, we actually will use `_supported_datamodule_types` with `isinstance` instead. + try: + self._skip_if_unsupported("datamodule", datamodule_name, skip_or_xfail=SKIP_OR_XFAIL) + except TypeError: + pass return datamodule_name @pytest.fixture(params=get_all_network_names(), scope="class") @@ -564,7 +566,12 @@ def on_train_batch_end( batch: tuple[Tensor, Tensor], batch_index: int, ) -> None: - assert self.metric in trainer.logged_metrics, (self.metric, trainer.logged_metrics.keys()) + if self.metric not in trainer.logged_metrics: + logger.warning( + f"Unable to get the metric {self.metric} from the logged metrics: " + f"{trainer.logged_metrics.keys()} at step {trainer.global_step}." 
+ ) + return metric_value = trainer.logged_metrics[self.metric] assert isinstance(metric_value, Tensor) self.metrics.append(metric_value.detach().item()) @@ -633,7 +640,6 @@ class AllParamsShouldHaveGradients(GetGradientsCallback): def __init__(self, exceptions: Sequence[str] = ()) -> None: super().__init__() self.exceptions = exceptions - self.gradients: dict[str, Tensor] = {} def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: @@ -649,7 +655,7 @@ def on_after_backward(self, trainer: Trainer, pl_module: LightningModule) -> Non def on_train_batch_end( self, trainer: Trainer, - pl_module: LightningModule, + pl_module: Algorithm, outputs: STEP_OUTPUT, batch: Any, batch_index: int, diff --git a/project/algorithms/bases/__init__.py b/project/algorithms/bases/__init__.py deleted file mode 100644 index c32c735a..00000000 --- a/project/algorithms/bases/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .algorithm import Algorithm -from .image_classification import ImageClassificationAlgorithm - -__all__ = ["Algorithm", "ImageClassificationAlgorithm"] diff --git a/project/algorithms/bases/algorithm.py b/project/algorithms/bases/algorithm.py deleted file mode 100644 index 5806c625..00000000 --- a/project/algorithms/bases/algorithm.py +++ /dev/null @@ -1,108 +0,0 @@ -from abc import ABC, abstractmethod -from collections.abc import Sequence -from dataclasses import dataclass -from typing import Any, TypedDict - -import torch -from lightning import Callback, LightningModule, Trainer -from torch import Tensor, nn -from typing_extensions import Generic, TypeVar # noqa - -from project.utils.types import NestedMapping, PhaseStr -from project.utils.types.protocols import DataModule, Module -from project.utils.utils import get_device - - -class StepOutputDict(TypedDict, total=False): - """A dictionary that shows what an Algorithm should output from `training/val/test_step`.""" - - loss: Tensor | float - """Optional loss tensor that can be returned by those 
methods.""" - - log: dict[str, Tensor | Any] - """Optional dictionary of things to log at each step.""" - - -BatchType = TypeVar("BatchType", bound=Tensor | Sequence[Tensor] | NestedMapping[str, Tensor]) -StepOutputType = TypeVar("StepOutputType", bound=StepOutputDict, default=StepOutputDict) -NetworkType = TypeVar("NetworkType", bound=Module, default=nn.Module) - - -class Algorithm(LightningModule, ABC, Generic[BatchType, StepOutputType, NetworkType]): - """Base class for a learning algorithm. - - This is an extension of the LightningModule class from PyTorch Lightning, with some common - boilerplate code to keep the algorithm implementations as simple as possible. - - The networks themselves are created separately and passed as a constructor argument. This is - meant to make it easier to compare different learning algorithms on the same network - architecture. - """ - - @dataclass - class HParams: - """Hyper-parameters of the algorithm.""" - - def __init__( - self, - *, - datamodule: DataModule[BatchType] | None = None, - network: NetworkType | None = None, - hp: HParams | None = None, - ): - super().__init__() - self.datamodule = datamodule - if isinstance(network, torch.nn.Module): - # fix for `self.device` property which defaults to cpu. - self._device = get_device(network) - elif network and not isinstance(network, torch.nn.Module): - # todo: Should we automatically convert jax networks to torch in case the base class - # doesn't? 
- pass - self.network = network - self.hp = hp or self.HParams() - self.trainer: Trainer - - def training_step(self, batch: BatchType, batch_index: int) -> StepOutputType: - """Performs a training step.""" - return self.shared_step(batch=batch, batch_index=batch_index, phase="train") - - def validation_step(self, batch: BatchType, batch_index: int) -> StepOutputType: - """Performs a validation step.""" - return self.shared_step(batch=batch, batch_index=batch_index, phase="val") - - def test_step(self, batch: BatchType, batch_index: int) -> StepOutputType: - """Performs a test step.""" - return self.shared_step(batch=batch, batch_index=batch_index, phase="test") - - def shared_step(self, batch: BatchType, batch_index: int, phase: PhaseStr) -> StepOutputType: - """Performs a training/validation/test step. - - This must return a dictionary with at least the 'y' and 'logits' keys, and an optional - `loss` entry. This is so that the training of the model is easier to parallelize the - training across GPUs: - - the cross entropy loss gets calculated using the global batch size - - the main metrics are logged inside `training_step_end` (supposed to be better for DP/DDP) - """ - raise NotImplementedError - - @abstractmethod - def configure_optimizers(self): - # """Creates the optimizers and the learning rate schedulers."""' - # super().configure_optimizers() - ... - - def forward(self, x: Tensor) -> Tensor: - """Performs a forward pass. - - Feel free to overwrite this to do whatever you'd like. 
- """ - return self.network(x) - - def configure_callbacks(self) -> list[Callback]: - """Use this to add some callbacks that should always be included with the model.""" - if getattr(self.hp, "use_scheduler", False) and self.trainer and self.trainer.logger: - from lightning.pytorch.callbacks.lr_monitor import LearningRateMonitor - - return [LearningRateMonitor()] - return [] diff --git a/project/algorithms/bases/image_classification.py b/project/algorithms/bases/image_classification.py deleted file mode 100644 index 49841119..00000000 --- a/project/algorithms/bases/image_classification.py +++ /dev/null @@ -1,179 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Required - -import torch -from torch import Tensor -from torch.nn import functional as F -from torchmetrics.classification import MulticlassAccuracy - -from project.algorithms.bases.algorithm import Algorithm, StepOutputDict -from project.datamodules.image_classification import ( - ImageClassificationDataModule, -) -from project.utils.types import PhaseStr -from project.utils.types.protocols import Module - -# TODO: Remove this `log` dict, perhaps it's better to use self.log of the pl module instead? - - -class ClassificationOutputs(StepOutputDict): - """The dictionary format that is minimally required to be returned from - `training/val/test_step`.""" - - logits: Required[Tensor] - """The un-normalized logits.""" - - y: Required[Tensor] - """The class labels.""" - - -class ImageClassificationAlgorithm[ - BatchType: tuple[Tensor, Tensor], - NetworkType: Module[[Tensor], Tensor], - StepOutputType: ClassificationOutputs, -](Algorithm[BatchType, StepOutputType, NetworkType], ABC): - """Base class for a learning algorithm for image classification. - - This is an extension of the LightningModule class from PyTorch Lightning, with some common - boilerplate code to keep the algorithm implementations as simple as possible. 
- - The network can be created separately. This makes it easier to compare different algorithms on the same architecture (e.g. your method vs a baseline). - """ - - @dataclass - class HParams(Algorithm.HParams): - """Hyper-parameters of the algorithm.""" - - def __init__( - self, - datamodule: ImageClassificationDataModule[BatchType], - network: NetworkType, - hp: ImageClassificationAlgorithm.HParams | None = None, - ): - super().__init__(datamodule=datamodule, network=network, hp=hp) - self.datamodule: ImageClassificationDataModule - # NOTE: Setting this property allows PL to infer the shapes and number of params. - # TODO: Check if PL now moves the `example_input_array` to the right device automatically. - # If possible, we'd like to remove any reference to the device from the algorithm. - self.example_input_array = torch.zeros( - [datamodule.batch_size, *datamodule.dims], - device=self.device, - ) - num_classes: int = datamodule.num_classes - - # IDEA: Could use a dict of metrics from torchmetrics instead of just accuracy: - # self.supervised_metrics: dist[str, Metrics] - # NOTE: Need to have one per phase! Not 100% sure that I'm not forgetting a phase here. 
- self.train_accuracy = MulticlassAccuracy(num_classes=num_classes) - self.val_accuracy = MulticlassAccuracy(num_classes=num_classes) - self.test_accuracy = MulticlassAccuracy(num_classes=num_classes) - self.train_top5_accuracy = MulticlassAccuracy(num_classes=num_classes, top_k=5) - self.val_top5_accuracy = MulticlassAccuracy(num_classes=num_classes, top_k=5) - self.test_top5_accuracy = MulticlassAccuracy(num_classes=num_classes, top_k=5) - - def training_step( - self, batch: tuple[Tensor, Tensor], batch_index: int - ) -> ClassificationOutputs: - """Performs a training step.""" - return self.shared_step(batch=batch, batch_index=batch_index, phase="train") - - def validation_step( - self, batch: tuple[Tensor, Tensor], batch_index: int - ) -> ClassificationOutputs: - """Performs a validation step.""" - return self.shared_step(batch=batch, batch_index=batch_index, phase="val") - - def test_step(self, batch: tuple[Tensor, Tensor], batch_index: int) -> ClassificationOutputs: - """Performs a test step.""" - return self.shared_step(batch=batch, batch_index=batch_index, phase="test") - - def predict_step(self, batch: Tensor, batch_index: int, dataloader_idx: int): - """Performs a prediction step.""" - return self.predict(batch) - - def predict(self, x: Tensor) -> Tensor: - """Predict the classification labels.""" - return self.network(x).argmax(-1) - - @abstractmethod - def shared_step( - self, batch: tuple[Tensor, Tensor], batch_index: int, phase: PhaseStr - ) -> ClassificationOutputs: - """Performs a training/validation/test step. - - This must return a dictionary with at least the 'y' and 'logits' keys, and an optional - `loss` entry. 
This is so that the training of the model is easier to parallelize the - training across GPUs: - - the cross entropy loss gets calculated using the global batch size - - the main metrics are logged inside `training_step_end` (supposed to be better for DP/DDP) - """ - - def training_step_end(self, step_output: ClassificationOutputs) -> ClassificationOutputs: - """Called with the results of each worker / replica's output. - - See the `training_step_end` of pytorch-lightning for more info. - """ - return self.shared_step_end(step_output, phase="train") - - def validation_step_end(self, step_output: ClassificationOutputs) -> ClassificationOutputs: - return self.shared_step_end(step_output, phase="val") - - def test_step_end(self, step_output: ClassificationOutputs) -> ClassificationOutputs: - return self.shared_step_end(step_output, phase="test") - - def shared_step_end( - self, step_output: ClassificationOutputs, phase: PhaseStr - ) -> ClassificationOutputs: - required_entries = ClassificationOutputs.__required_keys__ - if not isinstance(step_output, dict): - raise RuntimeError( - f"Expected the {phase} step method to output a dictionary with at least the " - f"{required_entries} keys, but got an output of type {type(step_output)} instead!" 
- ) - if not all(k in step_output for k in required_entries): - raise RuntimeError( - f"Expected all the following keys to be in the output of the {phase} step " - f"method: {required_entries}" - ) - logits = step_output["logits"] - y = step_output["y"] - - probs = torch.softmax(logits, -1) - - accuracy = getattr(self, f"{phase}_accuracy") - top5_accuracy = getattr(self, f"{phase}_top5_accuracy") - - assert isinstance(accuracy, MulticlassAccuracy) - assert isinstance(top5_accuracy, MulticlassAccuracy) - - # TODO: It's a bit confusing, not sure if this is the right way to use this: - accuracy(probs, y) - top5_accuracy(probs, y) - prog_bar = phase == "train" - - self.log(f"{phase}/accuracy", accuracy, prog_bar=prog_bar, sync_dist=True) - self.log(f"{phase}/top5_accuracy", top5_accuracy, prog_bar=prog_bar, sync_dist=True) - - if "cross_entropy" not in step_output: - # Add the cross entropy loss as a metric. - ce_loss = F.cross_entropy(logits.detach(), y, reduction="mean") - self.log(f"{phase}/cross_entropy", ce_loss, prog_bar=prog_bar, sync_dist=True) - - fused_output = step_output.copy() - loss: Tensor | float | None = step_output.get("loss", None) - - if isinstance(loss, Tensor) and loss.shape: - # Replace the loss with its mean. This is useful when automatic - # optimization is enabled, for example in the baseline (backprop), where each replica - # returns the un-reduced cross-entropy loss. Here we need to reduce it to a scalar. 
- fused_output["loss"] = loss.mean() - - if loss is not None: - self.log( - f"{phase}/loss", torch.as_tensor(loss).mean(), prog_bar=prog_bar, sync_dist=True - ) - - return fused_output diff --git a/project/algorithms/callbacks/callback.py b/project/algorithms/callbacks/callback.py index ae8f12be..8b0bb3ca 100644 --- a/project/algorithms/callbacks/callback.py +++ b/project/algorithms/callbacks/callback.py @@ -2,21 +2,35 @@ from logging import getLogger as get_logger from pathlib import Path -from typing import override +from typing import Literal, override +import torch from lightning import Trainer from lightning import pytorch as pl from typing_extensions import Generic # noqa -from project.algorithms.bases.algorithm import Algorithm, BatchType, StepOutputType -from project.utils.types import PhaseStr, StageStr +from project.algorithms.algorithm import Algorithm, BatchType, StepOutputDict, StepOutputType +from project.utils.types import PhaseStr, PyTree from project.utils.utils import get_log_dir logger = get_logger(__name__) -class Callback(pl.Callback, Generic[BatchType, StepOutputType]): - """Adds a bit of typing info and shared functions to the PyTorch Lightning Callback class.""" +class Callback[BatchType: PyTree[torch.Tensor], StepOutputType: torch.Tensor | StepOutputDict]( + pl.Callback +): + """Adds a bit of typing info and shared functions to the PyTorch Lightning Callback class. + + Adds the following typing information: + - The type of inputs that the algorithm takes + - The type of outputs that are returned by the algorithm's `[training/validation/test]_step` methods. 
+ + Adds the following methods: + - `on_shared_batch_start`: called by `on_[train/validation/test]_batch_start` + - `on_shared_batch_end`: called by `on_[train/validation/test]_batch_end` + - `on_shared_epoch_start`: called by `on_[train/validation/test]_epoch_start` + - `on_shared_epoch_end`: called by `on_[train/validation/test]_epoch_end` + """ def __init__(self) -> None: super().__init__() @@ -24,9 +38,12 @@ def __init__(self) -> None: @override def setup( - self, trainer: pl.Trainer, pl_module: Algorithm[BatchType, StepOutputType], stage: StageStr + self, + trainer: pl.Trainer, + pl_module: Algorithm[BatchType, StepOutputType], + # todo: "tune" is mentioned in the docstring, is it still used? + stage: Literal["fit", "validate", "test", "predict", "tune"], ) -> None: - """Called when fit, validate, test, predict, or tune begins.""" self.log_dir = get_log_dir(trainer=trainer) def on_shared_batch_start( @@ -37,7 +54,11 @@ def on_shared_batch_start( batch_index: int, phase: PhaseStr, dataloader_idx: int | None = None, - ): ... + ): + """Shared hook, called by `on_[train/validation/test]_batch_start`. + + Use this if you want to do something at the start of batches in more than one phase. + """ def on_shared_batch_end( self, @@ -48,15 +69,33 @@ def on_shared_batch_end( batch_index: int, phase: PhaseStr, dataloader_idx: int | None = None, - ): ... + ): + """Shared hook, called by `on_[train/validation/test]_batch_end`. + + Use this if you want to do something at the end of batches in more than one phase. + """ def on_shared_epoch_start( - self, trainer: Trainer, pl_module: Algorithm[BatchType, StepOutputType], phase: PhaseStr - ) -> None: ... + self, + trainer: Trainer, + pl_module: Algorithm[BatchType, StepOutputType], + phase: PhaseStr, + ) -> None: + """Shared hook, called by `on_[train/validation/test]_epoch_start`. + + Use this if you want to do something at the start of epochs in more than one phase. 
+ """ def on_shared_epoch_end( - self, trainer: Trainer, pl_module: Algorithm[BatchType, StepOutputType], phase: PhaseStr - ) -> None: ... + self, + trainer: Trainer, + pl_module: Algorithm[BatchType, StepOutputType], + phase: PhaseStr, + ) -> None: + """Shared hook, called by `on_[train/validation/test]_epoch_end`. + + Use this if you want to do something at the end of epochs in more than one phase. + """ @override def on_train_batch_end( @@ -70,7 +109,7 @@ def on_train_batch_end( super().on_train_batch_end( trainer=trainer, pl_module=pl_module, - outputs=outputs, # type: ignore + outputs=outputs, batch=batch, batch_idx=batch_index, ) diff --git a/project/algorithms/callbacks/classification_metrics.py b/project/algorithms/callbacks/classification_metrics.py index fbbad03b..c76d9db9 100644 --- a/project/algorithms/callbacks/classification_metrics.py +++ b/project/algorithms/callbacks/classification_metrics.py @@ -1,6 +1,6 @@ import warnings from logging import getLogger as get_logger -from typing import Any, Required +from typing import NotRequired, Required, TypedDict, override import torch import torchmetrics @@ -8,19 +8,21 @@ from torch import Tensor from torchmetrics.classification import MulticlassAccuracy -from project.algorithms.bases.algorithm import Algorithm, BatchType -from project.algorithms.bases.image_classification import StepOutputDict +from project.algorithms.algorithm import Algorithm, BatchType from project.algorithms.callbacks.callback import Callback -from project.utils.types import PhaseStr, StageStr +from project.utils.types import PhaseStr from project.utils.types.protocols import ClassificationDataModule logger = get_logger(__name__) -class ClassificationOutputs(StepOutputDict): +class ClassificationOutputs(TypedDict, total=False): """The dictionary format that is minimally required to be returned from `training/val/test_step` for classification algorithms.""" + loss: NotRequired[torch.Tensor | float] + """The loss at this step.""" + 
logits: Required[Tensor] """The un-normalized logits.""" @@ -78,11 +80,12 @@ def _set_metric(pl_module: LightningModule, name: str, metric: torchmetrics.Metr def _get_metric(pl_module: LightningModule, name: str): return getattr(pl_module, name) + @override def setup( self, trainer: Trainer, - pl_module: Algorithm[BatchType, ClassificationOutputs, Any], - stage: StageStr, + pl_module: Algorithm[BatchType, ClassificationOutputs], + stage: PhaseStr, ) -> None: if self.disabled: return @@ -101,10 +104,11 @@ def setup( num_classes = datamodule.num_classes self.add_metrics_to(pl_module, num_classes=num_classes) + @override def on_shared_batch_end( self, trainer: Trainer, - pl_module: Algorithm[BatchType, ClassificationOutputs, Any], + pl_module: Algorithm[BatchType, ClassificationOutputs], outputs: ClassificationOutputs, batch: BatchType, batch_index: int, @@ -149,7 +153,6 @@ def on_shared_batch_end( accuracy(probs, y) top5_accuracy(probs, y) prog_bar = phase == "train" - pl_module.log(f"{phase}/accuracy", accuracy, prog_bar=prog_bar, sync_dist=True) pl_module.log(f"{phase}/top5_accuracy", top5_accuracy, prog_bar=prog_bar, sync_dist=True) diff --git a/project/algorithms/callbacks/samples_per_second.py b/project/algorithms/callbacks/samples_per_second.py index 6b6e2b12..187bd247 100644 --- a/project/algorithms/callbacks/samples_per_second.py +++ b/project/algorithms/callbacks/samples_per_second.py @@ -1,10 +1,11 @@ import time +from typing import override from lightning import LightningModule, Trainer -from torch import Tensor, nn +from torch import Tensor from torch.optim import Optimizer -from project.algorithms.bases.algorithm import Algorithm, BatchType, StepOutputDict +from project.algorithms.algorithm import Algorithm, BatchType, StepOutputDict from project.algorithms.callbacks.callback import Callback from project.utils.types import PhaseStr, is_sequence_of @@ -16,10 +17,11 @@ def __init__(self): self.last_update_time: dict[int, float | None] = {} 
self.num_optimizers: int | None = None + @override def on_shared_epoch_start( self, trainer: Trainer, - pl_module: Algorithm[BatchType, StepOutputDict, nn.Module], + pl_module: Algorithm[BatchType, StepOutputDict], phase: PhaseStr, ) -> None: self.last_update_time.clear() @@ -31,10 +33,11 @@ def on_shared_epoch_start( else: self.num_optimizers = len(optimizer_or_optimizers) + @override def on_shared_batch_end( self, trainer: Trainer, - pl_module: Algorithm[BatchType, StepOutputDict, nn.Module], + pl_module: Algorithm[BatchType, StepOutputDict], outputs: StepOutputDict, batch: BatchType, batch_index: int, @@ -66,6 +69,7 @@ def on_shared_batch_end( # todo: support other kinds of batches self.last_step_times[phase] = now + @override def on_before_optimizer_step( self, trainer: Trainer, pl_module: LightningModule, optimizer: Optimizer, opt_idx: int = 0 ) -> None: diff --git a/project/algorithms/bases/image_classification_test.py b/project/algorithms/classification_tests.py similarity index 69% rename from project/algorithms/bases/image_classification_test.py rename to project/algorithms/classification_tests.py index 5282de07..3833311a 100644 --- a/project/algorithms/bases/image_classification_test.py +++ b/project/algorithms/classification_tests.py @@ -1,33 +1,42 @@ from __future__ import annotations -import itertools from pathlib import Path -from typing import ClassVar, TypeVar +from typing import ClassVar import pytest import torch.testing from torch import Tensor, nn from torch.utils.data import DataLoader, TensorDataset -from project.algorithms.bases.algorithm_test import ( +from project.algorithms.algorithm import Algorithm +from project.algorithms.algorithm_test import ( AlgorithmTests, CheckBatchesAreTheSameAtEachStep, MetricShouldImprove, ) -from project.algorithms.bases.image_classification import ImageClassificationAlgorithm -from project.datamodules.image_classification import ( +from project.algorithms.callbacks.classification_metrics import 
ClassificationOutputs +from project.datamodules.image_classification.image_classification import ( ImageClassificationDataModule, ) from project.utils.types import DataModule +from project.utils.types.protocols import ( + ClassificationDataModule, +) + +# Having tuple[torch.Tensor, torch.Tensor] as the batch type isn't ideal. -ImageAlgorithmType = TypeVar("ImageAlgorithmType", bound=ImageClassificationAlgorithm) +class ClassificationAlgorithmTests[ + AlgorithmType: Algorithm[tuple[Tensor, Tensor], ClassificationOutputs] +](AlgorithmTests[AlgorithmType]): + """Test suite for (image) classification algorithms.""" -class ImageClassificationAlgorithmTests(AlgorithmTests[ImageAlgorithmType]): unsupported_datamodule_types: ClassVar[list[type[DataModule]]] = [] unsupported_network_types: ClassVar[list[type[nn.Module]]] = [] - _supported_datamodule_types: ClassVar[list[type[ImageClassificationDataModule]]] = [ - ImageClassificationDataModule + _supported_datamodule_types: ClassVar[list[type[ClassificationDataModule]]] = [ + # VisionDataModule, + ClassificationDataModule, # type: ignore (we actually support this case). + # ImageClassificationDataModule, ] metric_name: ClassVar[str] = "train/accuracy" @@ -36,7 +45,7 @@ class ImageClassificationAlgorithmTests(AlgorithmTests[ImageAlgorithmType]): def test_output_shapes( self, - algorithm: ImageAlgorithmType, + algorithm: AlgorithmType, training_batch: tuple[Tensor, Tensor], ): """Tests that the output of the algorithm has the correct shape.""" @@ -47,6 +56,7 @@ def test_output_shapes( else: y_pred = output assert isinstance(y_pred, Tensor) + assert isinstance(algorithm.datamodule, ClassificationDataModule) if y_pred.dtype.is_floating_point: # y_pred should be the logits. 
assert y_pred.shape == (y.shape[0], algorithm.datamodule.num_classes) @@ -68,8 +78,7 @@ def training_batch( @pytest.fixture(scope="class") def repeat_first_batch_dataloader( self, - # algorithm: ImageAlgorithmType, - datamodule: ImageClassificationDataModule, + training_batch: tuple[Tensor, Tensor], n_updates: int, ): """Returns a dataloader that yields a exactly the same batch over and over again. @@ -81,28 +90,21 @@ def repeat_first_batch_dataloader( """ # Doing this just in case the algorithm wraps the datamodule somehow. # dm = getattr(algorithm, "datamodule", datamodule) - dm = datamodule - dm.prepare_data() - dm.setup("fit") + assert len(training_batch) + dataset = TensorDataset(*training_batch) + # need `start` to be of the same type, and it's hard to make an empty TensorDataset. + n_batches_dataset = sum([dataset] * (n_updates - 1), start=dataset) - train_dataloader = dm.train_dataloader() - assert isinstance(train_dataloader, DataLoader) - batch = next(iter(train_dataloader)) - batches = list(itertools.repeat(batch, n_updates)) - n_batches_dataset = TensorDataset( - *(torch.concatenate([b[i] for b in batches]) for i in range(len(batches[0]))) - ) - train_dl = DataLoader( - n_batches_dataset, batch_size=train_dataloader.batch_size, shuffle=False - ) - torch.testing.assert_close(next(iter(train_dl)), batch) + batch_size = training_batch[0].shape[0] + train_dl = DataLoader(n_batches_dataset, batch_size=batch_size, shuffle=False) + torch.testing.assert_close(next(iter(train_dl)), training_batch) return train_dl @pytest.mark.slow @pytest.mark.timeout(10) def test_overfit_exact_same_training_batch( self, - algorithm: ImageAlgorithmType, + algorithm: AlgorithmType, repeat_first_batch_dataloader: DataLoader, accelerator: str, devices: list[int], @@ -110,6 +112,7 @@ def test_overfit_exact_same_training_batch( tmp_path: Path, ): """Perform `n_updates` training steps on exactly the same batch of training data.""" + testing_callbacks = self.get_testing_callbacks() + [ 
CheckBatchesAreTheSameAtEachStep(), MetricShouldImprove(metric=self.metric_name, lower_is_better=self.lower_is_better), diff --git a/project/algorithms/example_algo.py b/project/algorithms/example_algo.py index 92c6902b..9d1d424c 100644 --- a/project/algorithms/example_algo.py +++ b/project/algorithms/example_algo.py @@ -11,6 +11,7 @@ from logging import getLogger from typing import Any +import torch from hydra_zen import instantiate from lightning.pytorch.callbacks import Callback, EarlyStopping from torch import Tensor @@ -18,29 +19,29 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler -from project.algorithms.bases.image_classification import ( +from project.algorithms.algorithm import Algorithm +from project.algorithms.callbacks.classification_metrics import ( + ClassificationMetricsCallback, ClassificationOutputs, - ImageClassificationAlgorithm, ) from project.configs.algorithm.lr_scheduler import CosineAnnealingLRConfig from project.configs.algorithm.optimizer import AdamConfig -from project.datamodules.image_classification import ( +from project.datamodules.image_classification.image_classification import ( ImageClassificationDataModule, ) from project.utils.types import PhaseStr -from project.utils.types.protocols import Module logger = getLogger(__name__) -class ExampleAlgorithm(ImageClassificationAlgorithm): +class ExampleAlgorithm(Algorithm): """Example algorithm for image classification.""" # TODO: Make this less specific to Image classification once we add other supervised learning # settings. @dataclass - class HParams(ImageClassificationAlgorithm.HParams): + class HParams(Algorithm.HParams): """Hyper-Parameters of the baseline model.""" # Arguments to be passed to the LR scheduler. 
@@ -68,16 +69,23 @@ class HParams(ImageClassificationAlgorithm.HParams): def __init__( self, datamodule: ImageClassificationDataModule, - network: Module[[Tensor], Tensor], + network: torch.nn.Module, hp: ExampleAlgorithm.HParams | None = None, ): - super().__init__(datamodule=datamodule, network=network, hp=hp) - self.datamodule: ImageClassificationDataModule - self.hp: ExampleAlgorithm.HParams + super().__init__() + self.datamodule = datamodule + self.network = network + self.hp = hp or self.HParams() + self.automatic_optimization = True - # Initialize any lazy weights. + # Used by PL to compute the input/output shapes of the network. + self.example_input_array = torch.zeros( + (datamodule.batch_size, *datamodule.dims), device=self.device + ) + # Initialize any lazy weights. Necessary for distributed training and to infer shapes. _ = self.network(self.example_input_array) + # Save hyper-parameters. self.save_hyperparameters({"hp": dataclasses.asdict(self.hp)}) def forward(self, input: Tensor) -> Tensor: @@ -92,20 +100,10 @@ def shared_step( ) -> ClassificationOutputs: x, y = batch logits = self(x) - # reduction=none to get the proper gradients in a backward pass when using multiple gpus. 
- loss = F.cross_entropy(logits, y, reduction="none") + loss = F.cross_entropy(logits, y, reduction="mean") self.log(f"{phase}/loss", loss.detach().mean()) - - # probs = torch.softmax(logits, -1) - # accuracy = getattr(self, f"{phase}_accuracy") - # top5_accuracy = getattr(self, f"{phase}_top5_accuracy") - - # # TODO: It's a bit confusing, not sure if this is the right way to use this: - # accuracy(probs, y) - # top5_accuracy(probs, y) - # prog_bar = phase == "train" - # self.log(f"{phase}/accuracy", accuracy, prog_bar=prog_bar, sync_dist=True) - # self.log(f"{phase}/top5_accuracy", top5_accuracy, prog_bar=prog_bar, sync_dist=True) + acc = logits.detach().argmax(-1).eq(y).float().mean() + self.log(f"{phase}/accuracy", acc) return {"loss": loss, "logits": logits, "y": y} def configure_optimizers(self) -> dict: @@ -128,7 +126,9 @@ def configure_optimizers(self) -> dict: return optimizers def configure_callbacks(self) -> list[Callback]: - callbacks: list[Callback] = super().configure_callbacks() + callbacks: list[Callback] = [ + ClassificationMetricsCallback.attach_to(self, num_classes=self.datamodule.num_classes) + ] if self.hp.early_stopping_patience != 0: # If early stopping is enabled, add a PL Callback for it: callbacks.append( diff --git a/project/algorithms/example_algo_test.py b/project/algorithms/example_algo_test.py index 55c67195..52b624b0 100644 --- a/project/algorithms/example_algo_test.py +++ b/project/algorithms/example_algo_test.py @@ -2,12 +2,12 @@ import torch -from project.algorithms.bases.image_classification_test import ImageClassificationAlgorithmTests +from project.algorithms.classification_tests import ClassificationAlgorithmTests from .example_algo import ExampleAlgorithm -class TestExampleAlgorithm(ImageClassificationAlgorithmTests[ExampleAlgorithm]): +class TestExampleAlgorithm(ClassificationAlgorithmTests[ExampleAlgorithm]): algorithm_type = ExampleAlgorithm algorithm_name: str = "example_algo" unsupported_datamodule_names: 
ClassVar[list[str]] = ["rl"] diff --git a/project/algorithms/jax_algo.py b/project/algorithms/jax_algo.py index 5b527e7d..97becd04 100644 --- a/project/algorithms/jax_algo.py +++ b/project/algorithms/jax_algo.py @@ -6,9 +6,6 @@ import flax.linen import jax -import lightning -import lightning.pytorch -import lightning.pytorch.callbacks import rich import rich.logging import torch @@ -16,10 +13,12 @@ from lightning import Callback, Trainer from torch_jax_interop import WrappedJaxFunction, torch_to_jax -from project.algorithms.bases.algorithm import Algorithm +from project.algorithms.algorithm import Algorithm from project.algorithms.callbacks.classification_metrics import ClassificationMetricsCallback from project.algorithms.callbacks.samples_per_second import MeasureSamplesPerSecondCallback -from project.datamodules.image_classification.base import ImageClassificationDataModule +from project.datamodules.image_classification.image_classification import ( + ImageClassificationDataModule, +) from project.datamodules.image_classification.mnist import MNISTDataModule from project.utils.types import PhaseStr from project.utils.types.protocols import ClassificationDataModule @@ -196,11 +195,13 @@ def main(): logging.basicConfig( level=logging.INFO, format="%(message)s", handlers=[rich.logging.RichHandler()] ) + from lightning.pytorch.callbacks import RichProgressBar + trainer = Trainer( devices="auto", max_epochs=10, accelerator="auto", - callbacks=[lightning.pytorch.callbacks.RichProgressBar()], + callbacks=[RichProgressBar()], ) datamodule = MNISTDataModule(num_workers=4, batch_size=512) network = CNN(num_classes=datamodule.num_classes) diff --git a/project/algorithms/jax_algo_test.py b/project/algorithms/jax_algo_test.py index 2c5acd52..ed280101 100644 --- a/project/algorithms/jax_algo_test.py +++ b/project/algorithms/jax_algo_test.py @@ -6,7 +6,7 @@ from project.algorithms.jax_algo import JaxAlgorithm -from .bases.algorithm_test import AlgorithmTests +from 
.algorithm_test import AlgorithmTests class TestJaxAlgorithm(AlgorithmTests[JaxAlgorithm]): diff --git a/project/algorithms/manual_optimization_example.py b/project/algorithms/manual_optimization_example.py index 3941965e..317a6e1b 100644 --- a/project/algorithms/manual_optimization_example.py +++ b/project/algorithms/manual_optimization_example.py @@ -5,22 +5,23 @@ import torch from torch import Tensor, nn -from project.algorithms.bases.image_classification import ( +from project.algorithms.algorithm import Algorithm +from project.algorithms.callbacks.classification_metrics import ( + ClassificationMetricsCallback, ClassificationOutputs, - ImageClassificationAlgorithm, ) -from project.datamodules.image_classification import ( +from project.datamodules.image_classification.image_classification import ( ImageClassificationDataModule, ) from project.utils.types import PhaseStr -class ManualGradientsExample(ImageClassificationAlgorithm): +class ManualGradientsExample(Algorithm): """Example of an algorithm that calculates the gradients manually instead of having PL do the backward pass.""" @dataclass - class HParams(ImageClassificationAlgorithm.HParams): + class HParams(Algorithm.HParams): """Hyper-parameters of this example algorithm.""" lr: float = 0.1 @@ -34,7 +35,10 @@ def __init__( network: nn.Module, hp: ManualGradientsExample.HParams | None = None, ): - super().__init__(datamodule=datamodule, network=network, hp=hp or self.HParams()) + super().__init__() + self.datamodule = datamodule + self.network = network + self.hp = hp or self.HParams() # Just to let the type checker know the right type. self.hp: ManualGradientsExample.HParams @@ -44,6 +48,9 @@ def __init__( self.automatic_optimization = False # Instantiate any lazy weights with a dummy forward pass (optional). 
+ self.example_input_array = torch.zeros( + (datamodule.batch_size, *datamodule.dims), device=self.device + ) self.network(self.example_input_array) def forward(self, x: Tensor) -> Tensor: @@ -62,14 +69,7 @@ def validation_step( def shared_step( self, batch: tuple[Tensor, Tensor], batch_index: int, phase: PhaseStr ) -> ClassificationOutputs: - """Performs a training/validation/test step. - - This must return a dictionary with at least the 'y' and 'logits' keys, and an optional - `loss` entry. This is so that the training of the model is easier to parallelize the - training across GPUs: - - the cross entropy loss gets calculated using the global batch size - - the main metrics are logged inside `training_step_end` (supposed to be better for DP/DDP) - """ + """Performs a training/validation/test step.""" x, y = batch logits = self(x) @@ -92,11 +92,10 @@ def shared_step( # NOTE: You don't need to call `loss.backward()`, you could also just set .grads # directly! - loss.backward() + self.manual_backward(loss) for name, parameter in self.named_parameters(): - if parameter.grad is None: - continue + assert parameter.grad is not None, name parameter.grad += self.hp.gradient_noise_std * torch.randn_like(parameter.grad) optimizer.step() @@ -106,3 +105,8 @@ def shared_step( def configure_optimizers(self): """Creates the optimizer(s) and learning rate scheduler(s).""" return torch.optim.SGD(self.parameters(), lr=self.hp.lr) + + def configure_callbacks(self): + return super().configure_callbacks() + [ + ClassificationMetricsCallback.attach_to(self, num_classes=self.datamodule.num_classes) + ] diff --git a/project/algorithms/manual_optimization_example_test.py b/project/algorithms/manual_optimization_example_test.py index 8509df95..1185e153 100644 --- a/project/algorithms/manual_optimization_example_test.py +++ b/project/algorithms/manual_optimization_example_test.py @@ -2,14 +2,15 @@ import torch -from project.algorithms.bases.image_classification_test import 
ImageClassificationAlgorithmTests +from project.algorithms.classification_tests import ClassificationAlgorithmTests +from project.datamodules.vision import VisionDataModule from .manual_optimization_example import ManualGradientsExample -class TestManualOptimizationExample(ImageClassificationAlgorithmTests[ManualGradientsExample]): +class TestManualOptimizationExample(ClassificationAlgorithmTests[ManualGradientsExample]): algorithm_type = ManualGradientsExample algorithm_name: str = "manual_optimization" - unsupported_datamodule_names: ClassVar[list[str]] = ["rl"] + _supported_datamodule_types: ClassVar[list[type]] = [VisionDataModule] _supported_network_types: ClassVar[list[type]] = [torch.nn.Module] diff --git a/project/algorithms/no_op.py b/project/algorithms/no_op.py index d1b26ad3..029f8b16 100644 --- a/project/algorithms/no_op.py +++ b/project/algorithms/no_op.py @@ -4,8 +4,7 @@ from lightning import Callback from torch import nn -from project.algorithms.bases import Algorithm -from project.algorithms.bases.algorithm import StepOutputDict +from project.algorithms.algorithm import Algorithm, StepOutputDict from project.algorithms.callbacks.samples_per_second import MeasureSamplesPerSecondCallback from project.utils.types import PhaseStr from project.utils.types.protocols import DataModule diff --git a/project/configs/network/__init__.py b/project/configs/network/__init__.py index 31723745..29c4721c 100644 --- a/project/configs/network/__init__.py +++ b/project/configs/network/__init__.py @@ -2,7 +2,6 @@ import torchvision.models from hydra_zen import store -from project.networks.fcnet import FcNet from project.utils.hydra_utils import interpolate_config_attribute network_store = store(group="network") @@ -14,13 +13,3 @@ ), name="resnet18", ) -network_store( - hydra_zen.builds( - FcNet, - hydra_convert="object", - hydra_recursive=True, - populate_full_signature=True, - output_dims=interpolate_config_attribute("datamodule.num_classes"), - ), - name="fcnet", -) 
diff --git a/project/configs/network/fcnet.yaml b/project/configs/network/fcnet.yaml new file mode 100644 index 00000000..539d1fc1 --- /dev/null +++ b/project/configs/network/fcnet.yaml @@ -0,0 +1,5 @@ +_target_: project.networks.fcnet.FcNet +output_dims: ${instance_attr:datamodule.num_classes} +input_shape: ${instance_attr:datamodule.dims} +hparams: + _target_: project.networks.fcnet.HParams diff --git a/project/configs/trainer/default.yaml b/project/configs/trainer/default.yaml index f5549142..f4cec3bc 100644 --- a/project/configs/trainer/default.yaml +++ b/project/configs/trainer/default.yaml @@ -2,7 +2,7 @@ _target_: lightning.Trainer logger: null accelerator: auto -strategy: null +strategy: auto devices: 1 min_epochs: 1 diff --git a/project/conftest.py b/project/conftest.py index 5cf5a914..bbf29d58 100644 --- a/project/conftest.py +++ b/project/conftest.py @@ -1,7 +1,6 @@ from __future__ import annotations import os -import random import sys import typing import warnings @@ -21,7 +20,7 @@ from torch.utils.data import DataLoader from project.configs.config import Config -from project.datamodules.image_classification import ( +from project.datamodules.image_classification.image_classification import ( ImageClassificationDataModule, ) from project.datamodules.vision import VisionDataModule @@ -38,9 +37,12 @@ from project.utils.testutils import ( default_marks_for_config_combinations, default_marks_for_config_name, + fork_rng, ) from project.utils.types import is_sequence_of -from project.utils.types.protocols import DataModule +from project.utils.types.protocols import ( + DataModule, +) if typing.TYPE_CHECKING: from _pytest.mark.structures import ParameterSet @@ -142,16 +144,10 @@ def seed(request: pytest.FixtureRequest): """Fixture that seeds everything for reproducibility and yields the random seed used.""" random_seed = getattr(request, "param", DEFAULT_SEED) assert isinstance(random_seed, int) or random_seed is None - - random_state = random.getstate() - 
np_random_state = np.random.get_state() - with torch.random.fork_rng(devices=list(range(torch.cuda.device_count()))): + with fork_rng(): seed_everything(random_seed, workers=True) yield random_seed - random.setstate(random_state) - np.random.set_state(np_random_state) - @pytest.fixture(scope="session") def accelerator(request: pytest.FixtureRequest): diff --git a/project/datamodules/__init__.py b/project/datamodules/__init__.py index 40eb5928..72563061 100644 --- a/project/datamodules/__init__.py +++ b/project/datamodules/__init__.py @@ -1,6 +1,6 @@ -from .image_classification import ImageClassificationDataModule from .image_classification.cifar10 import CIFAR10DataModule, cifar10_normalization from .image_classification.fashion_mnist import FashionMNISTDataModule +from .image_classification.image_classification import ImageClassificationDataModule from .image_classification.imagenet import ImageNetDataModule from .image_classification.imagenet32 import ImageNet32DataModule, imagenet32_normalization from .image_classification.inaturalist import INaturalistDataModule diff --git a/project/datamodules/datamodules_test.py b/project/datamodules/datamodules_test.py index e0e9666c..a2f07815 100644 --- a/project/datamodules/datamodules_test.py +++ b/project/datamodules/datamodules_test.py @@ -3,34 +3,88 @@ import matplotlib.pyplot as plt import pytest +import torch +from lightning import LightningDataModule +from lightning.fabric.utilities.exceptions import MisconfigurationException +from lightning.pytorch.trainer.states import RunningStage from tensor_regression.fixture import ( TensorRegressionFixture, get_test_source_and_temp_file_paths, ) from torch import Tensor +from project.datamodules.image_classification.image_classification import ( + ImageClassificationDataModule, +) +from project.datamodules.vision import VisionDataModule from project.utils.testutils import run_for_all_datamodules from project.utils.types import is_sequence_of -from ..utils.types.protocols 
import DataModule - +# @use_overrides(["datamodule.num_workers=0"]) # @pytest.mark.timeout(25, func_only=True) @pytest.mark.slow +@pytest.mark.parametrize( + "stage", + [ + RunningStage.TRAINING, + RunningStage.VALIDATING, + RunningStage.TESTING, + pytest.param( + RunningStage.PREDICTING, + marks=pytest.mark.xfail( + reason="Might not be implemented by the datamodule.", + raises=MisconfigurationException, + ), + ), + ], +) @run_for_all_datamodules() def test_first_batch( - datamodule: DataModule, + datamodule: LightningDataModule, request: pytest.FixtureRequest, tensor_regression: TensorRegressionFixture, original_datadir: Path, + stage: RunningStage, datadir: Path, ): # todo: skip this test if the dataset isn't already downloaded (for example on the GitHub CI). datamodule.prepare_data() - datamodule.setup("fit") + if stage == RunningStage.TRAINING: + datamodule.setup("fit") + dataloader = datamodule.train_dataloader() + elif stage in [RunningStage.VALIDATING, RunningStage.SANITY_CHECKING]: + datamodule.setup("validate") + dataloader = datamodule.val_dataloader() + elif stage == RunningStage.TESTING: + datamodule.setup("test") + dataloader = datamodule.test_dataloader() + else: + assert stage == RunningStage.PREDICTING + datamodule.setup("predict") + dataloader = datamodule.predict_dataloader() + + batch = next(iter(dataloader)) + + from torchvision.tv_tensors import Image + + if isinstance(datamodule, ImageClassificationDataModule): + assert isinstance(batch, list | tuple) and len(batch) == 2 + # todo: if we tighten this and make it so vision datamodules return Images, then we should + # have strict asserts here that check that batch[0] is an Image. It doesn't seem to be the case though. 
+ # assert isinstance(batch[0], Image) + assert isinstance(batch[0], torch.Tensor) + assert isinstance(batch[1], torch.Tensor) + elif isinstance(datamodule, VisionDataModule): + if isinstance(batch, list | tuple): + # assert isinstance(batch[0], Image) + assert isinstance(batch[0], torch.Tensor) + else: + assert isinstance(batch, torch.Tensor) + assert isinstance(batch, Image) - batch = next(iter(datamodule.train_dataloader())) if isinstance(batch, dict): + # fixme: leftover from the RL datamodule proof-of-concept. if "infos" in batch: # todo: fix this, unsupported because of `object` dtype. batch.pop("infos") @@ -70,7 +124,14 @@ def test_first_batch( # moving mnist, y isn't a label, it's another image. axis.set_title(f"{index=}") - fig.suptitle(f"First batch of datamodule {type(datamodule).__name__}") + split = { + RunningStage.TRAINING: "training", + RunningStage.VALIDATING: "validation", + RunningStage.TESTING: "test", + RunningStage.PREDICTING: "prediction(?)", + } + + fig.suptitle(f"First {split[stage]} batch of datamodule {type(datamodule).__name__}") figure_path, _ = get_test_source_and_temp_file_paths( extension=".png", request=request, diff --git a/project/datamodules/datamodules_test/test_first_batch/cifar10_test.yaml b/project/datamodules/datamodules_test/test_first_batch/cifar10_test.yaml new file mode 100644 index 00000000..51af2350 --- /dev/null +++ b/project/datamodules/datamodules_test/test_first_batch/cifar10_test.yaml @@ -0,0 +1,21 @@ +'0': + device: cpu + hash: 7631136576767235544 + max: 1.0 + mean: 0.468 + min: 0.0 + shape: + - 128 + - 3 + - 32 + - 32 + sum: 184156.109 +'1': + device: cpu + hash: 8462625093735455128 + max: 9 + mean: 4.703 + min: 0 + shape: + - 128 + sum: 602 diff --git a/project/datamodules/datamodules_test/test_first_batch/cifar10.yaml b/project/datamodules/datamodules_test/test_first_batch/cifar10_train.yaml similarity index 100% rename from project/datamodules/datamodules_test/test_first_batch/cifar10.yaml rename to 
project/datamodules/datamodules_test/test_first_batch/cifar10_train.yaml diff --git a/project/datamodules/datamodules_test/test_first_batch/cifar10_validate.yaml b/project/datamodules/datamodules_test/test_first_batch/cifar10_validate.yaml new file mode 100644 index 00000000..cc9fa7bc --- /dev/null +++ b/project/datamodules/datamodules_test/test_first_batch/cifar10_validate.yaml @@ -0,0 +1,21 @@ +'0': + device: cpu + hash: 4180642819611736479 + max: 1.0 + mean: 0.463 + min: 0.0 + shape: + - 128 + - 3 + - 32 + - 32 + sum: 181864.641 +'1': + device: cpu + hash: -4539052997197868398 + max: 9 + mean: 4.258 + min: 0 + shape: + - 128 + sum: 545 diff --git a/project/datamodules/datamodules_test/test_first_batch/fashion_mnist_test.yaml b/project/datamodules/datamodules_test/test_first_batch/fashion_mnist_test.yaml new file mode 100644 index 00000000..9eba458f --- /dev/null +++ b/project/datamodules/datamodules_test/test_first_batch/fashion_mnist_test.yaml @@ -0,0 +1,21 @@ +'0': + device: cpu + hash: -2751264324508784427 + max: 1.0 + mean: 0.292 + min: 0.0 + shape: + - 128 + - 1 + - 28 + - 28 + sum: 29317.309 +'1': + device: cpu + hash: 6530176971009424370 + max: 9 + mean: 4.461 + min: 0 + shape: + - 128 + sum: 571 diff --git a/project/datamodules/datamodules_test/test_first_batch/fashion_mnist.yaml b/project/datamodules/datamodules_test/test_first_batch/fashion_mnist_train.yaml similarity index 100% rename from project/datamodules/datamodules_test/test_first_batch/fashion_mnist.yaml rename to project/datamodules/datamodules_test/test_first_batch/fashion_mnist_train.yaml diff --git a/project/datamodules/datamodules_test/test_first_batch/fashion_mnist_validate.yaml b/project/datamodules/datamodules_test/test_first_batch/fashion_mnist_validate.yaml new file mode 100644 index 00000000..587f7c8f --- /dev/null +++ b/project/datamodules/datamodules_test/test_first_batch/fashion_mnist_validate.yaml @@ -0,0 +1,21 @@ +'0': + device: cpu + hash: 225494219076660575 + max: 1.0 + mean: 
0.296 + min: 0.0 + shape: + - 128 + - 1 + - 28 + - 28 + sum: 29740.449 +'1': + device: cpu + hash: -4543745818595514203 + max: 9 + mean: 4.453 + min: 0 + shape: + - 128 + sum: 570 diff --git a/project/datamodules/datamodules_test/test_first_batch/imagenet32_test.yaml b/project/datamodules/datamodules_test/test_first_batch/imagenet32_test.yaml new file mode 100644 index 00000000..0dc846b7 --- /dev/null +++ b/project/datamodules/datamodules_test/test_first_batch/imagenet32_test.yaml @@ -0,0 +1,21 @@ +'0': + device: cpu + hash: -5724309328014586573 + max: 1.0 + mean: 0.461 + min: 0.0 + shape: + - 64 + - 3 + - 32 + - 32 + sum: 90649.305 +'1': + device: cpu + hash: 2830952008253455204 + max: 987 + mean: 543.234 + min: 49 + shape: + - 64 + sum: 34767 diff --git a/project/datamodules/datamodules_test/test_first_batch/imagenet32.yaml b/project/datamodules/datamodules_test/test_first_batch/imagenet32_train.yaml similarity index 100% rename from project/datamodules/datamodules_test/test_first_batch/imagenet32.yaml rename to project/datamodules/datamodules_test/test_first_batch/imagenet32_train.yaml diff --git a/project/datamodules/datamodules_test/test_first_batch/imagenet32_validate.yaml b/project/datamodules/datamodules_test/test_first_batch/imagenet32_validate.yaml new file mode 100644 index 00000000..5ffdf2fd --- /dev/null +++ b/project/datamodules/datamodules_test/test_first_batch/imagenet32_validate.yaml @@ -0,0 +1,21 @@ +'0': + device: cpu + hash: 4266338311425013668 + max: 1.0 + mean: 0.427 + min: 0.0 + shape: + - 64 + - 3 + - 32 + - 32 + sum: 83882.633 +'1': + device: cpu + hash: 5813156328689991827 + max: 973 + mean: 484.469 + min: 21 + shape: + - 64 + sum: 31006 diff --git a/project/datamodules/datamodules_test/test_first_batch/imagenet_test.yaml b/project/datamodules/datamodules_test/test_first_batch/imagenet_test.yaml new file mode 100644 index 00000000..884b4752 --- /dev/null +++ b/project/datamodules/datamodules_test/test_first_batch/imagenet_test.yaml @@ -0,0 
+1,21 @@ +'0': + device: cpu + hash: 8711678139956893479 + max: 2.64 + mean: -0.181 + min: -2.118 + shape: + - 64 + - 3 + - 224 + - 224 + sum: -1740804.5 +'1': + device: cpu + hash: -3826088756534882585 + max: 1 + mean: 0.219 + min: 0 + shape: + - 64 + sum: 14 diff --git a/project/datamodules/datamodules_test/test_first_batch/imagenet.yaml b/project/datamodules/datamodules_test/test_first_batch/imagenet_train.yaml similarity index 100% rename from project/datamodules/datamodules_test/test_first_batch/imagenet.yaml rename to project/datamodules/datamodules_test/test_first_batch/imagenet_train.yaml diff --git a/project/datamodules/datamodules_test/test_first_batch/imagenet_validate.yaml b/project/datamodules/datamodules_test/test_first_batch/imagenet_validate.yaml new file mode 100644 index 00000000..87a7b5cc --- /dev/null +++ b/project/datamodules/datamodules_test/test_first_batch/imagenet_validate.yaml @@ -0,0 +1,21 @@ +'0': + device: cpu + hash: 976242047177418374 + max: 2.64 + mean: -0.118 + min: -2.118 + shape: + - 64 + - 3 + - 224 + - 224 + sum: -1139394.375 +'1': + device: cpu + hash: -5258163774450544391 + max: 0 + mean: 0.0 + min: 0 + shape: + - 64 + sum: 0 diff --git a/project/datamodules/datamodules_test/test_first_batch/mnist_test.yaml b/project/datamodules/datamodules_test/test_first_batch/mnist_test.yaml new file mode 100644 index 00000000..8361f2a8 --- /dev/null +++ b/project/datamodules/datamodules_test/test_first_batch/mnist_test.yaml @@ -0,0 +1,21 @@ +'0': + device: cpu + hash: -627002388361970843 + max: 1.0 + mean: 0.118 + min: 0.0 + shape: + - 128 + - 1 + - 28 + - 28 + sum: 11872.632 +'1': + device: cpu + hash: -7950905935926016059 + max: 9 + mean: 4.555 + min: 0 + shape: + - 128 + sum: 583 diff --git a/project/datamodules/datamodules_test/test_first_batch/mnist.yaml b/project/datamodules/datamodules_test/test_first_batch/mnist_train.yaml similarity index 100% rename from project/datamodules/datamodules_test/test_first_batch/mnist.yaml rename to 
project/datamodules/datamodules_test/test_first_batch/mnist_train.yaml diff --git a/project/datamodules/datamodules_test/test_first_batch/mnist_validate.yaml b/project/datamodules/datamodules_test/test_first_batch/mnist_validate.yaml new file mode 100644 index 00000000..27f56994 --- /dev/null +++ b/project/datamodules/datamodules_test/test_first_batch/mnist_validate.yaml @@ -0,0 +1,21 @@ +'0': + device: cpu + hash: 8917266713855133233 + max: 1.0 + mean: 0.135 + min: 0.0 + shape: + - 128 + - 1 + - 28 + - 28 + sum: 13531.503 +'1': + device: cpu + hash: -2353573324666895086 + max: 9 + mean: 4.328 + min: 0 + shape: + - 128 + sum: 554 diff --git a/project/datamodules/image_classification/__init__.py b/project/datamodules/image_classification/__init__.py index 7c3bca5f..44482982 100644 --- a/project/datamodules/image_classification/__init__.py +++ b/project/datamodules/image_classification/__init__.py @@ -1,3 +1,3 @@ -from .base import ImageClassificationDataModule +from .image_classification import ImageClassificationDataModule __all__ = ["ImageClassificationDataModule"] diff --git a/project/datamodules/image_classification/base.py b/project/datamodules/image_classification/base.py deleted file mode 100644 index 331cfbe6..00000000 --- a/project/datamodules/image_classification/base.py +++ /dev/null @@ -1,18 +0,0 @@ -from __future__ import annotations - -from torch import Tensor - -from project.datamodules.vision import VisionDataModule -from project.utils.types import C, H, W - -# todo: decide if this should be a protocol or an actual base class (currently a base class). 
- - -class ImageClassificationDataModule[BatchType: tuple[Tensor, Tensor]](VisionDataModule[BatchType]): - """Lightning data modules for image classification.""" - - num_classes: int - """Number of classes in the dataset.""" - - dims: tuple[C, H, W] - """A tuple describing the shape of the data.""" diff --git a/project/datamodules/image_classification/cifar10.py b/project/datamodules/image_classification/cifar10.py index 2d8b382c..91cbc92e 100644 --- a/project/datamodules/image_classification/cifar10.py +++ b/project/datamodules/image_classification/cifar10.py @@ -1,14 +1,15 @@ from __future__ import annotations from collections.abc import Callable -from typing import Any import torch from torchvision.datasets import CIFAR10 -from torchvision.transforms import v2 as transform_lib from torchvision.transforms import v2 as transforms -from project.datamodules.image_classification.base import ImageClassificationDataModule +from project.datamodules.image_classification.image_classification import ( + ImageClassificationDataModule, +) +from project.datamodules.vision import VisionDataModule from project.utils.types import C, H, W @@ -20,6 +21,7 @@ def cifar10_train_transforms(): transforms.RandomCrop(size=32, padding=4, padding_mode="edge"), transforms.ToDtype(torch.float32, scale=True), cifar10_normalization(), + transforms.ToImage(), ] ) @@ -40,7 +42,7 @@ def cifar10_unnormalization(x: torch.Tensor) -> torch.Tensor: return (x * std) + mean -class CIFAR10DataModule(ImageClassificationDataModule): +class CIFAR10DataModule(ImageClassificationDataModule, VisionDataModule): """ .. 
figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2019/01/ Plot-of-a-Subset-of-Images-from-the-CIFAR-10-Dataset.png @@ -87,47 +89,6 @@ class CIFAR10DataModule(ImageClassificationDataModule): dims = (C(3), H(32), W(32)) num_classes = 10 - def __init__( - self, - data_dir: str | None = None, - val_split: int | float = 0.2, - num_workers: int | None = 0, - normalize: bool = False, - batch_size: int = 32, - seed: int = 42, - shuffle: bool = True, - pin_memory: bool = True, - drop_last: bool = False, - *args: Any, - **kwargs: Any, - ) -> None: - """ - Args: - data_dir: Where to save/load the data - val_split: Percent (float) or number (int) of samples to use for the validation split - num_workers: How many workers to use for loading data - normalize: If true applies image normalize - batch_size: How many samples per batch to load - seed: Random seed to be used for train/val/test splits - shuffle: If true shuffles the train data every epoch - pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before - returning them - drop_last: If true drops the last incomplete batch - """ - super().__init__( # type: ignore[misc] - data_dir=data_dir, - val_split=val_split, - num_workers=num_workers, - normalize=normalize, - batch_size=batch_size, - seed=seed, - shuffle=shuffle, - pin_memory=pin_memory, - drop_last=drop_last, - *args, - **kwargs, - ) - @property def num_samples(self) -> int: train_len, _ = self._get_splits(len_dataset=50_000) @@ -135,18 +96,19 @@ def num_samples(self) -> int: def default_transforms(self) -> Callable: if self.normalize: - cf10_transforms = transform_lib.Compose( + cf10_transforms = transforms.Compose( [ - transform_lib.ToImage(), - transform_lib.ToDtype(torch.float32, scale=True), + transforms.ToImage(), + transforms.ToDtype(torch.float32, scale=True), cifar10_normalization(), + transforms.ToImage(), # unsure if this is necessary. 
] ) else: - cf10_transforms = transform_lib.Compose( + cf10_transforms = transforms.Compose( [ - transform_lib.ToImage(), - transform_lib.ToDtype(torch.float32, scale=True), + transforms.ToImage(), + transforms.ToDtype(torch.float32, scale=True), ] ) diff --git a/project/datamodules/image_classification/image_classification.py b/project/datamodules/image_classification/image_classification.py new file mode 100644 index 00000000..86961aac --- /dev/null +++ b/project/datamodules/image_classification/image_classification.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from torch import Tensor +from torchvision.tv_tensors import Image + +from project.datamodules.vision import VisionDataModule +from project.utils.types import C, H, W +from project.utils.types.protocols import ClassificationDataModule + +# todo: need to decide whether this should be a base class or just a protocol. +# - IF this is a protocol, then we can't use issubclass with it, so it can't be used in the +# `supported_datamodule_types` field on AlgorithmTests subclasses (for example `ClassificationAlgorithmTests`). 
+ + +class ImageClassificationDataModule[BatchType: tuple[Image, Tensor]]( + VisionDataModule[BatchType], ClassificationDataModule[BatchType] +): + """Lightning data modules for image classification.""" + + num_classes: int + """Number of classes in the dataset.""" + + dims: tuple[C, H, W] + """A tuple describing the shape of the data.""" diff --git a/project/datamodules/image_classification/imagenet.py b/project/datamodules/image_classification/imagenet.py index 60554038..c888e1a5 100644 --- a/project/datamodules/image_classification/imagenet.py +++ b/project/datamodules/image_classification/imagenet.py @@ -22,8 +22,8 @@ from torchvision.transforms import v2 as transform_lib from project.datamodules.vision import VisionDataModule -from project.utils.env_vars import DATA_DIR, NUM_WORKERS -from project.utils.types import C, H, StageStr, W +from project.utils.env_vars import DATA_DIR, NETWORK_DIR, NUM_WORKERS +from project.utils.types import C, H, W from project.utils.types.protocols import Module logger = get_logger(__name__) @@ -108,7 +108,7 @@ def __init__( drop_last=drop_last, train_transforms=train_transforms or self.train_transform(), val_transforms=val_transforms or self.val_transform(), - test_transforms=test_transforms, + test_transforms=test_transforms or self.test_transform(), **kwargs, ) self.dims = (C(3), H(self.image_size), W(self.image_size)) @@ -118,7 +118,16 @@ def __init__( # self.test_dataset_cls = UnlabeledImagenet def prepare_data(self) -> None: - network_imagenet_dir = Path("/network/datasets/imagenet") + if ( + not NETWORK_DIR + or not (network_imagenet_dir := NETWORK_DIR / "datasets" / "imagenet").exists() + ): + raise NotImplementedError( + "Assuming that the imagenet dataset can be found at " + "${NETWORK_DIR:-/network}/datasets/imagenet, (using $NETWORK_DIR if set, else " + "'/network'), but this path doesn't exist!" 
+ ) + logger.debug(f"Preparing ImageNet train split in {self.data_dir}...") prepare_imagenet( self.data_dir, @@ -134,7 +143,7 @@ def prepare_data(self) -> None: super().prepare_data() - def setup(self, stage: StageStr | None = None) -> None: + def setup(self, stage: Literal["fit", "validate", "test", "predict"] | None = None) -> None: logger.debug(f"Setup ImageNet datamodule for {stage=}") super().setup(stage) @@ -233,11 +242,15 @@ def val_transform(self) -> Callable: ] ) + # todo: what should be the default transformations for the test set? Same as validation, right? + test_transform = val_transform + def prepare_imagenet( root: Path, + *, split: Literal["train", "val"] = "train", - network_imagenet_dir: Path = Path("/network/datasets/imagenet"), + network_imagenet_dir: Path, ) -> None: """Custom preparation function for ImageNet, using @obilaniu's tar magic in Python form. diff --git a/project/datamodules/image_classification/imagenet32.py b/project/datamodules/image_classification/imagenet32.py index c66eb0f3..874a053a 100644 --- a/project/datamodules/image_classification/imagenet32.py +++ b/project/datamodules/image_classification/imagenet32.py @@ -7,7 +7,7 @@ from collections.abc import Callable, Sequence from logging import getLogger from pathlib import Path -from typing import ClassVar +from typing import ClassVar, Literal import gdown import numpy as np @@ -19,7 +19,7 @@ from project.datamodules.vision import VisionDataModule from project.utils.env_vars import DATA_DIR, SCRATCH -from project.utils.types import C, H, StageStr, W +from project.utils.types import C, H, W logger = getLogger(__name__) @@ -233,11 +233,8 @@ def prepare_data(self) -> None: """Saves files to data_dir.""" super().prepare_data() - def setup(self, stage: StageStr | None = None) -> None: - """Creates train, val, and test dataset.""" - if stage not in ["fit", "validate", "val", "test", None]: - raise ValueError(f"Invalid stage: {stage}") - + def setup(self, stage: Literal["fit", 
"validate", "test", "predict"] | None = None) -> None: + # """Creates train, val, and test dataset.""" if stage: logger.debug(f"Setting up for stage {stage}") else: @@ -269,7 +266,7 @@ def setup(self, stage: StageStr | None = None) -> None: self.dataset_train = self._split_dataset(base_dataset_train, train=True) self.dataset_val = self._split_dataset(base_dataset_valid, train=False) - if stage in ["test", None]: + if stage in ["test", "predict", None]: test_transforms = self.test_transforms or self.default_transforms() self.dataset_test = self.dataset_cls( self.data_dir, train=False, transform=test_transforms, **self.EXTRA_ARGS diff --git a/project/datamodules/image_classification/inaturalist.py b/project/datamodules/image_classification/inaturalist.py index a1090a72..ca67f7ed 100644 --- a/project/datamodules/image_classification/inaturalist.py +++ b/project/datamodules/image_classification/inaturalist.py @@ -9,7 +9,9 @@ import torchvision.transforms as T from torchvision.datasets import INaturalist -from project.datamodules.image_classification.base import ImageClassificationDataModule +from project.datamodules.image_classification.image_classification import ( + ImageClassificationDataModule, +) from project.utils.env_vars import DATA_DIR, NUM_WORKERS, SLURM_TMPDIR from project.utils.types import C, H, W diff --git a/project/datamodules/image_classification/inaturalist_test.py b/project/datamodules/image_classification/inaturalist_test.py index 9457e44a..7b9757f8 100644 --- a/project/datamodules/image_classification/inaturalist_test.py +++ b/project/datamodules/image_classification/inaturalist_test.py @@ -6,7 +6,7 @@ from torchvision import transforms as T from torchvision.datasets import INaturalist -from project.datamodules.image_classification import ( +from project.datamodules.image_classification.image_classification import ( ImageClassificationDataModule, ) diff --git a/project/datamodules/image_classification/mnist.py 
b/project/datamodules/image_classification/mnist.py index 905aa261..1b96f850 100644 --- a/project/datamodules/image_classification/mnist.py +++ b/project/datamodules/image_classification/mnist.py @@ -8,7 +8,9 @@ from torchvision.datasets import MNIST from torchvision.transforms import v2 as transforms -from project.datamodules.image_classification.base import ImageClassificationDataModule +from project.datamodules.image_classification.image_classification import ( + ImageClassificationDataModule, +) from project.utils.types import C, H, W diff --git a/project/datamodules/vision.py b/project/datamodules/vision.py index c2d0f6b1..94e10fd4 100644 --- a/project/datamodules/vision.py +++ b/project/datamodules/vision.py @@ -6,16 +6,18 @@ from collections.abc import Callable from logging import getLogger as get_logger from pathlib import Path -from typing import ClassVar, Concatenate +from typing import ClassVar, Concatenate, Literal import torch from lightning import LightningDataModule from torch.utils.data import DataLoader, Dataset, random_split +from torch.utils.data._utils.collate import collate_tensor_fn, default_collate_fn_map from torchvision.datasets import VisionDataset from torchvision.transforms import v2 as transforms +from torchvision.tv_tensors import Image, set_return_type from project.utils.env_vars import DATA_DIR, NUM_WORKERS -from project.utils.types import C, H, StageStr, W +from project.utils.types import C, H, W from project.utils.types.protocols import DataModule logger = get_logger(__name__) @@ -128,8 +130,7 @@ def prepare_data(self) -> None: ) self.test_dataset_cls(str(self.data_dir), **test_kwargs) - def setup(self, stage: StageStr | None = None) -> None: - """Creates train, val, and test dataset.""" + def setup(self, stage: Literal["fit", "validate", "test", "predict"] | None = None) -> None: if stage in ["fit", "validate"] or stage is None: logger.debug(f"creating training dataset with kwargs {self.train_kwargs}") dataset_train = 
self.dataset_cls( @@ -141,6 +142,11 @@ def setup(self, stage: StageStr | None = None) -> None: str(self.data_dir), **self.valid_kwargs, ) + + # TODO: If we support more datasets than image classification, we could add this: + # dataset_train = wrap_dataset_for_transforms_v2(dataset_train) + # dataset_val = wrap_dataset_for_transforms_v2(dataset_val) + # Train/validation split. # NOTE: the dataset is created twice (with the right transforms) and split in the same # way, such that there is no overlap in indices between train and validation sets. @@ -188,17 +194,15 @@ def train_dataloader[**P]( **kwargs: P.kwargs, ) -> DataLoader: """The train dataloader.""" + assert self.dataset_train is not None + kwargs = kwargs.copy() # type: ignore + kwargs.setdefault("shuffle", self.shuffle) + kwargs.setdefault("generator", torch.Generator().manual_seed(self.train_dl_rng_seed)) return self._data_loader( self.dataset_train, _dataloader_fn=_dataloader_fn, *args, - **( - dict( - shuffle=self.shuffle, - generator=torch.Generator().manual_seed(self.train_dl_rng_seed), - ) - | kwargs - ), + **kwargs, ) def val_dataloader[**P]( @@ -208,12 +212,14 @@ def val_dataloader[**P]( **kwargs: P.kwargs, ) -> DataLoader: """The val dataloader.""" - + assert self.dataset_val is not None + kwargs = kwargs.copy() # type: ignore + kwargs.setdefault("generator", torch.Generator().manual_seed(self.val_dl_rng_seed)) return self._data_loader( self.dataset_val, _dataloader_fn=_dataloader_fn, *args, - **(dict(generator=torch.Generator().manual_seed(self.val_dl_rng_seed)) | kwargs), + **kwargs, ) def test_dataloader[**P]( @@ -223,14 +229,14 @@ def test_dataloader[**P]( **kwargs: P.kwargs, ) -> DataLoader: """The test dataloader.""" - if self.dataset_test is None: - self.setup("test") assert self.dataset_test is not None + kwargs = kwargs.copy() # type: ignore + kwargs.setdefault("generator", torch.Generator().manual_seed(self.test_dl_rng_seed)) return self._data_loader( self.dataset_test, 
_dataloader_fn=_dataloader_fn, *args, - **(dict(generator=torch.Generator().manual_seed(self.test_dl_rng_seed)) | kwargs), + **kwargs, ) def _data_loader[**P]( @@ -246,13 +252,36 @@ def _data_loader[**P]( num_workers=self.num_workers, drop_last=self.drop_last, pin_memory=self.pin_memory, - persistent_workers=(self.num_workers or 0) > 0, + persistent_workers=self.num_workers > 0, ) | dataloader_kwargs ) + return _dataloader_fn(dataset, *dataloader_args, **dataloader_kwargs) +def collate_images( + images: list[Image], + *, + collate_fn_map: dict[type | tuple[type, ...], Callable] | None = None, +): + with set_return_type("TVTensor"): + # note: Image is a subclass of Tensor, but list[Image] is not a subclass of list[Tensor] + image_batch: Image = collate_tensor_fn(images) # type: ignore + assert isinstance(image_batch, Image), type(image_batch) + + if image_batch.ndim <= 4: + # We wouldn't want to return an Image for higher-dimensions, it probably wouldn't make sense. + # log_once( + # message="Collating images into `torchvision.tv_tensors.Image`s", level=logging.INFO + # ) + return image_batch + return image_batch.as_subclass(torch.Tensor) + + +default_collate_fn_map[Image] = collate_images + + def _has_constructor_argument(cls: type[VisionDataset], arg: str) -> bool: # TODO: Would be more accurate to check if cls has either download or a **kwargs argument and # then check if the base class constructor takes a `download` argument. 
@@ -260,7 +289,7 @@ def _has_constructor_argument(cls: type[VisionDataset], arg: str) -> bool: # Check if sig has a **kwargs argument if arg in sig.parameters: return True - if any(p.kind == p.VAR_KEYWORD for p in sig.parameters.values()): + if any(p.kind == p.VAR_KEYWORD for p in sig.parameters.values()) and cls.__base__ is not None: return _has_constructor_argument(cls.__base__, arg) return False diff --git a/project/experiment.py b/project/experiment.py index d437a9b1..a48f6390 100644 --- a/project/experiment.py +++ b/project/experiment.py @@ -19,12 +19,15 @@ from project.algorithms import Algorithm from project.configs.config import Config -from project.datamodules.image_classification import ( +from project.datamodules.image_classification.image_classification import ( ImageClassificationDataModule, ) from project.utils.hydra_utils import get_outer_class from project.utils.types import Dataclass -from project.utils.types.protocols import DataModule, Module +from project.utils.types.protocols import ( + DataModule, + Module, +) from project.utils.utils import validate_datamodule logger = get_logger(__name__) diff --git a/project/main.py b/project/main.py index c641cd2d..dc053068 100644 --- a/project/main.py +++ b/project/main.py @@ -14,7 +14,7 @@ from omegaconf import DictConfig from project.configs.config import Config -from project.datamodules.image_classification import ( +from project.datamodules.image_classification.image_classification import ( ImageClassificationDataModule, ) from project.experiment import Experiment, setup_experiment diff --git a/project/networks/fcnet.py b/project/networks/fcnet.py index a6b9c013..e4558d62 100644 --- a/project/networks/fcnet.py +++ b/project/networks/fcnet.py @@ -1,46 +1,43 @@ -from __future__ import annotations +"""An example of a simple fully connected network.""" -from dataclasses import dataclass, field +from dataclasses import field from functools import singledispatch import numpy as np +import pydantic +import 
pydantic.generics import torch from torch import Tensor, nn -from project.datamodules.image_classification import ( +from project.datamodules.image_classification.image_classification import ( ImageClassificationDataModule, ) +from project.utils.types import FloatBetween0And1 from project.utils.types.protocols import DataModule class Flatten(nn.Flatten): def forward(self, input: Tensor): # NOTE: The input Should have at least 2 dimensions for `nn.Flatten` to work, but it isn't - # the case with a single observation from a single env. + # the case with a single observation from a single environment. if input.ndim <= 1: return input if input.is_nested: - # NOTE: This makes 2d inputs 3d on purpose so they can be used with a nn.Flatten. return torch.nested.as_nested_tensor( [input_i.reshape([input_i.shape[0], -1]) for input_i in input.unbind()] ) - if input.ndim == 3: - # FIXME: Hacky: don't collapse the `sequence length` dimension here. - # TODO: Perhaps use a named dimension to detect this case? - return input.reshape([input.shape[0], input.shape[1], -1]) return super().forward(input) class FcNet(nn.Sequential): - @dataclass - class HParams: + class HParams(pydantic.BaseModel): """Dataclass containing the network hyper-parameters.""" - hidden_dims: list[int] = field(default_factory=[128, 128].copy) + hidden_dims: list[pydantic.PositiveInt] = field(default_factory=[128, 128].copy) use_bias: bool = True - dropout_rate: float = 0.5 + dropout_rate: FloatBetween0And1 = 0.5 """Dropout rate. Set to 0 to disable dropout. diff --git a/project/networks/layers/layers.py b/project/networks/layers/layers.py index e582d531..8cd0c080 100644 --- a/project/networks/layers/layers.py +++ b/project/networks/layers/layers.py @@ -174,6 +174,8 @@ def forward(self, packed_inputs: tuple[Tensor, ...] 
| dict[str, Tensor]) -> OutT class Sample(Lambda, Module[[torch.distributions.Distribution], Tensor]): + """Layer that samples from a distribution.""" + def __init__(self, differentiable: bool = False) -> None: super().__init__(f=operator.methodcaller("rsample" if differentiable else "sample")) self._differentiable = differentiable diff --git a/project/utils/env_vars.py b/project/utils/env_vars.py index cc9d663d..d7baeba1 100644 --- a/project/utils/env_vars.py +++ b/project/utils/env_vars.py @@ -21,6 +21,12 @@ if (_network_dir := Path("/network")).exists() else None ) +"""The (read-only) network directory that contains datasets/weights/etc. + +todo: adapt this for the DRAC clusters. + +When running outside of the mila/DRAC clusters, this will be `None`, but can be mocked by setting the `NETWORK_DIR` environment variable. +""" REPO_ROOTDIR = Path(__file__).parent for level in range(5): diff --git a/project/utils/testutils.py b/project/utils/testutils.py index 184a0605..49fae182 100644 --- a/project/utils/testutils.py +++ b/project/utils/testutils.py @@ -27,7 +27,7 @@ from torch.optim import Optimizer from project.configs import Config -from project.datamodules.image_classification import ( +from project.datamodules.image_classification.image_classification import ( ImageClassificationDataModule, ) from project.datamodules.vision import VisionDataModule @@ -35,7 +35,9 @@ from project.utils.env_vars import NETWORK_DIR from project.utils.hydra_utils import get_attr, get_outer_class from project.utils.types import PhaseStr -from project.utils.types.protocols import DataModule +from project.utils.types.protocols import ( + DataModule, +) from project.utils.utils import get_device logger = get_logger(__name__) @@ -386,7 +388,8 @@ def run_for_all_configs_in_group( k: default_marks_for_config_name.get(k, []) for k in get_all_configs_in_group(group_name) } - + # Parametrize the fixture (e.g. datamodule_name) indirectly, which will make it take each group + # member (e.g. 
datamodule config name), each with a parameterized mark. return pytest.mark.parametrize( f"{group_name}_name", [ @@ -624,7 +627,7 @@ def assert_no_nans_in_params_or_grads(module: nn.Module): @contextlib.contextmanager def fork_rng(): - with torch.random.fork_rng(): + with torch.random.fork_rng(devices=list(range(torch.cuda.device_count()))): random_state = random.getstate() np_random_state = np.random.get_state() yield diff --git a/project/utils/types/__init__.py b/project/utils/types/__init__.py index 6858b489..381ed570 100644 --- a/project/utils/types/__init__.py +++ b/project/utils/types/__init__.py @@ -1,20 +1,11 @@ from __future__ import annotations -from collections.abc import Mapping, Sequence -from typing import ( - Any, - Literal, - NewType, - TypeGuard, - Unpack, -) +from collections.abc import Iterable, Mapping, Sequence +from typing import Annotated, Any, Literal, NewType, TypeGuard, Unpack +import annotated_types from torch import Tensor -from typing_extensions import ( - ParamSpec, - TypeVar, - TypeVarTuple, -) +from typing_extensions import TypeVar, TypeVarTuple from .protocols import Dataclass, DataModule, HasInputOutputShapes, Module @@ -24,22 +15,27 @@ W = NewType("W", int) S = NewType("S", int) -StageStr = Literal["fit", "validate", "test", "predict"] + +# todo: Fix this. Why do we have these enums? Are they necessary? Could we use the same ones as PL if we wanted to? +# from lightning.pytorch.trainer.states import RunningStage as PhaseStr +# from lightning.pytorch.trainer.states import TrainerFn as StageStr + PhaseStr = Literal["train", "val", "test"] """The trainer phases. TODO: There has to exist an enum for it somewhere in PyTorch Lightning. 
""" -P = ParamSpec("P", default=[Tensor]) -R = ParamSpec("R") +# Types used with pydantic: +FloatBetween0And1 = Annotated[float, annotated_types.Ge(0), annotated_types.Le(1)] + OutT = TypeVar("OutT", default=Tensor, covariant=True) Ts = TypeVarTuple("Ts", default=Unpack[tuple[Tensor, ...]]) T = TypeVar("T", default=Tensor) type NestedDict[K, V] = dict[K, V | NestedDict[K, V]] type NestedMapping[K, V] = Mapping[K, V | NestedMapping[K, V]] -type PyTree[T] = T | tuple[PyTree[T], ...] | list[PyTree[T]] | Mapping[Any, PyTree[T]] +type PyTree[T] = T | Iterable[PyTree[T]] | Mapping[Any, PyTree[T]] def is_list_of[V](object: Any, item_type: type[V] | tuple[type[V], ...]) -> TypeGuard[list[V]]: diff --git a/project/utils/types/protocols.py b/project/utils/types/protocols.py index 25c4bd4a..eae2f45a 100644 --- a/project/utils/types/protocols.py +++ b/project/utils/types/protocols.py @@ -3,13 +3,10 @@ import dataclasses import typing from collections.abc import Iterable -from typing import ClassVar, Protocol, runtime_checkable +from typing import ClassVar, Literal, Protocol, runtime_checkable from torch import nn -if typing.TYPE_CHECKING: - from project.utils.types import StageStr - class Dataclass(Protocol): __dataclass_fields__: ClassVar[dict[str, dataclasses.Field]] @@ -52,16 +49,14 @@ class HasInputOutputShapes(Module, Protocol): @runtime_checkable class DataModule[BatchType](Protocol): - """Protocol that shows the expected attributes / methods of the `LightningDataModule` class. + """Protocol that shows the minimal attributes / methods of the `LightningDataModule` class. This is used to type hint the batches that are yielded by the DataLoaders. """ - # batch_size: int - def prepare_data(self) -> None: ... - def setup(self, stage: StageStr) -> None: ... + def setup(self, stage: Literal["train", "validate", "test", "predict"]) -> None: ... def train_dataloader(self) -> Iterable[BatchType]: ... @@ -69,3 +64,10 @@ def train_dataloader(self) -> Iterable[BatchType]: ... 
@runtime_checkable class ClassificationDataModule[BatchType](DataModule[BatchType], Protocol): num_classes: int + + +# todo: Decide if we want this to be a base class or a protocol. Currently a base class. +# @runtime_checkable +# class ImageClassificationDataModule[BatchType](DataModule[BatchType], Protocol): +# num_classes: int +# dims: tuple[C, H, W] diff --git a/project/utils/utils.py b/project/utils/utils.py index 8f02c6b8..0702f0f1 100644 --- a/project/utils/utils.py +++ b/project/utils/utils.py @@ -1,5 +1,7 @@ from __future__ import annotations +import functools +import typing import warnings from collections.abc import Iterable, Mapping, Sequence from dataclasses import field @@ -17,16 +19,29 @@ from torch.nn.parameter import Parameter from torchvision import transforms -from project.datamodules.image_classification import ( - ImageClassificationDataModule, +from project.utils.types.protocols import ( + DataModule, + Module, ) -from project.utils.types.protocols import DataModule, Module from .types import NestedDict, NestedMapping logger = get_logger(__name__) +# todo: doesn't work? keeps logging each time! +@functools.cache +def log_once(message: str, level: int) -> None: + """Logs a message once per logger instance. The message is logged at the specified level. + + Args: + logger: The logger instance to use. + message: The message to log. + level: The logging level to use. 
+ """ + logger.log(level=level, msg=message, stacklevel=2) + + def get_shape_ish(t: Tensor) -> tuple[int | Literal["?"], ...]: if not t.is_nested: return t.shape @@ -89,6 +104,15 @@ def get_devices(mod: Module) -> set[torch.device]: return set(p.device for p in mod.parameters()) +if typing.TYPE_CHECKING: + from project.datamodules.image_classification.image_classification import ( + ImageClassificationDataModule, + ) + + +# todo: shouldn't be here, should be done in `VisionDataModule` or in the configs: +# If `normalize=False`, and there is a normalization transform in the train transforms, then an +# error should be raised. def _remove_normalization_from_transforms( datamodule: ImageClassificationDataModule, ) -> None: @@ -116,6 +140,10 @@ def validate_datamodule[DM: DataModule | LightningDataModule](datamodule: DM) -> Returns the same datamodule. """ + from project.datamodules.image_classification.image_classification import ( + ImageClassificationDataModule, + ) + if isinstance(datamodule, ImageClassificationDataModule) and not datamodule.normalize: _remove_normalization_from_transforms(datamodule) else: