Simplify examples, reduce typing verbosity (#8)
* Simplify examples by reducing inheritance a bit

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix typing errors in samples_per_second.py

Signed-off-by: Fabrice Normandin <[email protected]>

* Simplify some of the algos some more

Signed-off-by: Fabrice Normandin <[email protected]>

* Remove ImageClassificationAlgorithm, improve tests

Signed-off-by: Fabrice Normandin <[email protected]>

* Remove the `bases` package, ImageClassification

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix import error

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix KeyError in GetMetricCallback

Signed-off-by: Fabrice Normandin <[email protected]>

* Add docstrings in Callback class

Signed-off-by: Fabrice Normandin <[email protected]>

* Tweak to Algorithm.shared_step_end

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix error in `Algorithm.shared_step_end`

Signed-off-by: Fabrice Normandin <[email protected]>

* Slightly simplify some of the test fixtures

Signed-off-by: Fabrice Normandin <[email protected]>

* Update regression files for the datamodules tests

Signed-off-by: Fabrice Normandin <[email protected]>

* Simplify FcNet config

Signed-off-by: Fabrice Normandin <[email protected]>

* Add missing MNIST regression files

Signed-off-by: Fabrice Normandin <[email protected]>

* Try to fix the "missing imagenet" error in CI

Signed-off-by: Fabrice Normandin <[email protected]>

* Add missing `NETWORK_DIR` in the devcontainer

Signed-off-by: Fabrice Normandin <[email protected]>

* Add comment

Signed-off-by: Fabrice Normandin <[email protected]>

---------

Signed-off-by: Fabrice Normandin <[email protected]>
lebrice authored Jun 27, 2024
1 parent a57ce73 commit 6a42ef0
Showing 59 changed files with 847 additions and 604 deletions.
3 changes: 2 additions & 1 deletion .devcontainer/devcontainer.json
@@ -60,7 +60,8 @@
},
"containerEnv": {
"SCRATCH": "/home/vscode/scratch",
"SLURM_TMPDIR": "/tmp"
"SLURM_TMPDIR": "/tmp",
"NETWORK_DIR": "/network"
},
"mounts": [
// https://code.visualstudio.com/remote/advancedcontainers/add-local-file-mount
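For reference, a minimal sketch of how project code running inside the devcontainer might pick up this new environment variable; the `get_network_dir` helper and the `/network` fallback are illustrative assumptions, not part of this commit.

```python
import os
from pathlib import Path


def get_network_dir() -> Path:
    # Hypothetical helper: read the NETWORK_DIR variable set in devcontainer.json,
    # falling back to the conventional /network mount if it is unset.
    return Path(os.environ.get("NETWORK_DIR", "/network"))
```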
3 changes: 2 additions & 1 deletion project/__init__.py
@@ -1,6 +1,7 @@
from .algorithms import Algorithm, ExampleAlgorithm, ManualGradientsExample, NoOp
from .configs import Config
from .datamodules import ImageClassificationDataModule, VisionDataModule
from .datamodules import VisionDataModule
from .datamodules.image_classification.image_classification import ImageClassificationDataModule
from .experiment import Experiment
from .networks import FcNet

5 changes: 2 additions & 3 deletions project/algorithms/__init__.py
@@ -3,8 +3,7 @@
from project.algorithms.jax_algo import JaxAlgorithm
from project.algorithms.no_op import NoOp

from .bases.algorithm import Algorithm
from .bases.image_classification import ImageClassificationAlgorithm
from .algorithm import Algorithm
from .example_algo import ExampleAlgorithm
from .manual_optimization_example import ManualGradientsExample

@@ -26,6 +25,6 @@
__all__ = [
"Algorithm",
"ExampleAlgorithm",
"ImageClassificationAlgorithm",
"ManualGradientsExample",
"JaxAlgorithm",
]
164 changes: 164 additions & 0 deletions project/algorithms/algorithm.py
@@ -0,0 +1,164 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import NotRequired, TypedDict

import torch
from lightning import Callback, LightningModule, Trainer
from torch import Tensor
from typing_extensions import Generic, TypeVar # noqa

from project.datamodules.image_classification.image_classification import (
ImageClassificationDataModule,
)
from project.utils.types import PhaseStr, PyTree
from project.utils.types.protocols import DataModule, Module


class StepOutputDict(TypedDict, total=False):
"""A dictionary that shows what an Algorithm can output from
`training/validation/test_step`."""

loss: NotRequired[Tensor | float]
"""Optional loss tensor that can be returned by those methods."""


BatchType = TypeVar("BatchType", bound=PyTree[torch.Tensor], contravariant=True)
# StepOutputType = TypeVar(
# "StepOutputType", bound=StepOutputDict | PyTree[torch.Tensor], covariant=True
# )
StepOutputType = TypeVar(
"StepOutputType",
bound=torch.Tensor | StepOutputDict,
default=StepOutputDict,
covariant=True,
)


class Algorithm(LightningModule, ABC, Generic[BatchType, StepOutputType]):
"""Base class for a learning algorithm.
This is an extension of the LightningModule class from PyTorch Lightning, with some common
boilerplate code to keep the algorithm implementations as simple as possible.
The networks themselves are created separately and passed as a constructor argument. This is
meant to make it easier to compare different learning algorithms on the same network
architecture.
"""

@dataclass
class HParams:
"""Hyper-parameters of the algorithm."""

def __init__(
self,
*,
datamodule: DataModule[BatchType] | None = None,
network: Module | None = None,
hp: HParams | None = None,
):
super().__init__()
self.datamodule = datamodule
self.network = network
self.hp = hp or self.HParams()
# fix for `self.device` property which defaults to cpu.
self._device = None

if isinstance(datamodule, ImageClassificationDataModule):
self.example_input_array = torch.zeros(
(datamodule.batch_size, *datamodule.dims), device=self.device
)

self.trainer: Trainer

def training_step(self, batch: BatchType, batch_index: int) -> StepOutputType:
"""Performs a training step."""
return self.shared_step(batch=batch, batch_index=batch_index, phase="train")

def validation_step(self, batch: BatchType, batch_index: int) -> StepOutputType:
"""Performs a validation step."""
return self.shared_step(batch=batch, batch_index=batch_index, phase="val")

def test_step(self, batch: BatchType, batch_index: int) -> StepOutputType:
"""Performs a test step."""
return self.shared_step(batch=batch, batch_index=batch_index, phase="test")

def shared_step(self, batch: BatchType, batch_index: int, phase: PhaseStr) -> StepOutputType:
"""Performs a training/validation/test step.
This must return a nested dictionary of tensors matching the `StepOutputType` typedict for
this algorithm. By default,
`loss` entry. This is so that the training of the model is easier to parallelize the
training across GPUs:
- the cross entropy loss gets calculated using the global batch size
- the main metrics are logged inside `training_step_end` (supposed to be better for DP/DDP)
"""
raise NotImplementedError

@abstractmethod
def configure_optimizers(self):
# """Creates the optimizers and the learning rate schedulers."""'
# super().configure_optimizers()
...

def forward(self, x: Tensor) -> Tensor:
"""Performs a forward pass.
Feel free to overwrite this to do whatever you'd like.
"""
assert self.network is not None
return self.network(x)

def training_step_end(self, step_output: StepOutputDict) -> StepOutputDict:
"""Called with the results of each worker / replica's output.
See the `training_step_end` of pytorch-lightning for more info.
"""
return self.shared_step_end(step_output, phase="train")

def validation_step_end[Out: torch.Tensor | StepOutputDict](self, step_output: Out) -> Out:
return self.shared_step_end(step_output, phase="val")

def test_step_end[Out: torch.Tensor | StepOutputDict](self, step_output: Out) -> Out:
return self.shared_step_end(step_output, phase="test")

def shared_step_end[Out: torch.Tensor | StepOutputDict](
self, step_output: Out, phase: PhaseStr
) -> Out:
"""This is a default implementation for `[train/validation/test]_step_end`.
This does the following:
- Averages out the `loss` tensor if it was left unreduced.
- the main metrics are logged inside `training_step_end` (supposed to be better for DP/DDP)
"""

if (
isinstance(step_output, dict)
and isinstance((loss := step_output.get("loss")), torch.Tensor)
and loss.shape
):
# Replace the loss with its mean. This is useful when automatic
# optimization is enabled, for example in the example algo, where each replica
# returns the un-reduced cross-entropy loss. Here we need to reduce it to a scalar.
# todo: find out if this was already logged, to not log it twice.
# self.log(f"{phase}/loss", loss.mean(), sync_dist=True)
return step_output | {"loss": loss.mean()}

elif isinstance(step_output, torch.Tensor) and (loss := step_output).shape:
return loss.mean()

# self.log(f"{phase}/loss", torch.as_tensor(loss).mean(), sync_dist=True)
return step_output

def configure_callbacks(self) -> list[Callback]:
"""Use this to add some callbacks that should always be included with the model."""
return []

@property
def device(self) -> torch.device:
if self._device is None:
self._device = next((p.device for p in self.parameters()), torch.device("cpu"))
device = self._device
# make this more explicit to always include the index
if device.type == "cuda" and device.index is None:
return torch.device("cuda", index=torch.cuda.current_device())
return device
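To make the intent of this new base class concrete, here is a rough sketch of how a subclass might look. The `ClassificationExample` name, its hyper-parameter, and the cross-entropy objective are illustrative assumptions rather than code from this commit (the real examples live in `example_algo.py` and `manual_optimization_example.py`).

```python
from dataclasses import dataclass

import torch
from torch import Tensor, nn
from torch.nn import functional as F

from project.algorithms.algorithm import Algorithm, StepOutputDict
from project.utils.types import PhaseStr
from project.utils.types.protocols import DataModule


class ClassificationExample(Algorithm):
    """Hypothetical algorithm: plain supervised classification with cross-entropy."""

    @dataclass
    class HParams(Algorithm.HParams):
        lr: float = 1e-3

    def __init__(self, datamodule: DataModule, network: nn.Module, hp: HParams | None = None):
        super().__init__(datamodule=datamodule, network=network, hp=hp or self.HParams())

    def shared_step(
        self, batch: tuple[Tensor, Tensor], batch_index: int, phase: PhaseStr
    ) -> StepOutputDict:
        x, y = batch
        logits = self.network(x)
        # Return the per-sample (un-reduced) loss: `shared_step_end` averages it,
        # which plays nicely with DP/DDP where each replica sees a slice of the batch.
        loss = F.cross_entropy(logits, y, reduction="none")
        self.log(f"{phase}/accuracy", (logits.argmax(-1) == y).float().mean())
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hp.lr)
```

Returning a `StepOutputDict` (rather than a bare tensor) is what lets `shared_step_end` find and reduce the `loss` entry into the scalar that automatic optimization expects.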
@@ -8,21 +8,23 @@
from collections.abc import Callable, Sequence
from logging import getLogger as get_logger
from pathlib import Path
from typing import Any, ClassVar, Generic, Literal, TypeVar
from typing import Any, ClassVar, Literal

import pytest
import torch
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning import LightningDataModule, LightningModule, Trainer
from lightning.pytorch.utilities.types import STEP_OUTPUT
from omegaconf import DictConfig
from tensor_regression import TensorRegressionFixture
from torch import Tensor, nn
from torch.utils.data import DataLoader
from typing_extensions import ParamSpec

from project.algorithms.algorithm import Algorithm
from project.algorithms.callbacks.callback import Callback
from project.configs import Config, cs
from project.conftest import setup_hydra_for_tests_and_compose
from project.datamodules.image_classification import (
from project.datamodules.image_classification.image_classification import (
ImageClassificationDataModule,
)
from project.datamodules.vision import VisionDataModule
@@ -42,21 +44,15 @@
)
from project.utils.types.protocols import DataModule

from .algorithm import Algorithm

logger = get_logger(__name__)
P = ParamSpec("P")

AlgorithmType = TypeVar("AlgorithmType", bound=Algorithm)


SKIP_OR_XFAIL = pytest.xfail if "-vvv" in sys.argv else pytest.skip
"""Either skips the test entirely (default) or tries to run it and expect it to fail (slower)."""

skip_test = pytest.mark.xfail if "-vvv" in sys.argv else pytest.mark.skip


class AlgorithmTests(Generic[AlgorithmType]):
class AlgorithmTests[AlgorithmType: Algorithm]:
"""Unit tests for an algorithm class.
The algorithm creation is parametrized with all the datasets and all the networks, but the
Expand Down Expand Up @@ -112,7 +108,7 @@ def n_updates(self) -> int:
```python
@pytest.fixture
def n_updates(self, datamodule_name: str, network_name: str) -> int:
if datamodule_name == "imagenet32":
return 10
return 3
@@ -141,6 +137,7 @@ def test_loss_is_reproducible(
)

def get_testing_callbacks(self) -> list[TestingCallback]:
"""Callbacks to be used for unit tests."""
return [
AllParamsShouldHaveGradients(),
]
@@ -207,7 +204,7 @@ def _train(
accelerator=accelerator,
default_root_dir=tmp_path,
callbacks=testing_callbacks.copy(), # type: ignore
# NOTE: Would be nice to be able to enforce this, but DTP uses nn.MaxUnpool2d.
# NOTE: Would be nice to be able to enforce this in general, but some algos could be using nn.MaxUnpool2d.
deterministic=True if can_use_deterministic_mode else "warn",
**trainer_kwargs,
)
@@ -293,7 +290,12 @@ def datamodule_name(self, request: pytest.FixtureRequest):
if datamodule_name in default_marks_for_config_name:
for marker in default_marks_for_config_name[datamodule_name]:
request.applymarker(marker)
self._skip_if_unsupported("datamodule", datamodule_name, skip_or_xfail=SKIP_OR_XFAIL)
# todo: if _supported_datamodule_types contains a protocol, this will raise a TypeError. In
# this case, we actually will use `_supported_datamodule_types` with `isinstance` instead.
try:
self._skip_if_unsupported("datamodule", datamodule_name, skip_or_xfail=SKIP_OR_XFAIL)
except TypeError:
pass
return datamodule_name

@pytest.fixture(params=get_all_network_names(), scope="class")
@@ -564,7 +566,12 @@ def on_train_batch_end(
batch: tuple[Tensor, Tensor],
batch_index: int,
) -> None:
assert self.metric in trainer.logged_metrics, (self.metric, trainer.logged_metrics.keys())
if self.metric not in trainer.logged_metrics:
logger.warning(
f"Unable to get the metric {self.metric} from the logged metrics: "
f"{trainer.logged_metrics.keys()} at step {trainer.global_step}."
)
return
metric_value = trainer.logged_metrics[self.metric]
assert isinstance(metric_value, Tensor)
self.metrics.append(metric_value.detach().item())
@@ -633,7 +640,6 @@ class AllParamsShouldHaveGradients(GetGradientsCallback):
def __init__(self, exceptions: Sequence[str] = ()) -> None:
super().__init__()
self.exceptions = exceptions

self.gradients: dict[str, Tensor] = {}

def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
@@ -649,7 +655,7 @@ def on_after_backward(self, trainer: Trainer, pl_module: LightningModule) -> None
def on_train_batch_end(
self,
trainer: Trainer,
pl_module: LightningModule,
pl_module: Algorithm,
outputs: STEP_OUTPUT,
batch: Any,
batch_index: int,
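As a usage sketch, a concrete algorithm's test suite might specialize the generic test class roughly as follows; the module path `project.algorithms.algorithm_test`, the class name, and any class-level configuration hooks are assumptions, since only part of the file is visible in this diff.

```python
import pytest

from project.algorithms.example_algo import ExampleAlgorithm

# Assumed module path for the test base class shown above; not confirmed by this diff.
from project.algorithms.algorithm_test import AlgorithmTests


class TestExampleAlgorithm(AlgorithmTests[ExampleAlgorithm]):
    """Hypothetical specialization of the shared test suite for `ExampleAlgorithm`.

    Class-level configuration (supported datamodules, networks, etc.) is omitted here
    because it is not visible in this diff.
    """

    @pytest.fixture
    def n_updates(self, datamodule_name: str, network_name: str) -> int:
        # Mirrors the example from the AlgorithmTests docstring: larger datasets get a
        # few more updates before checking that all parameters received gradients.
        if datamodule_name == "imagenet32":
            return 10
        return 3
```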
4 changes: 0 additions & 4 deletions project/algorithms/bases/__init__.py

This file was deleted.

