From 67274449e0125b6c3e809ef7d67075e542f04233 Mon Sep 17 00:00:00 2001
From: Fabrice Normandin
Date: Thu, 30 May 2024 00:09:44 +0000
Subject: [PATCH 01/27] Add stub for imagenet datamodule

Signed-off-by: Fabrice Normandin
---
 .../image_classification/imagenet.py          | 206 ++++++++++++++++++
 1 file changed, 206 insertions(+)
 create mode 100644 project/datamodules/image_classification/imagenet.py

diff --git a/project/datamodules/image_classification/imagenet.py b/project/datamodules/image_classification/imagenet.py
new file mode 100644
index 00000000..a42fca04
--- /dev/null
+++ b/project/datamodules/image_classification/imagenet.py
@@ -0,0 +1,206 @@
+"""ImageNet datamodule adapted to the Mila cluster.
+
+Can be used either with a PyTorch-Lightning Trainer, or by itself to easily get efficient
+dataloaders for the ImageNet dataset.
+
+Requirements (these are the versions I'm using, but they can probably be loosened a bit):
+- pytorch-lightning==1.6.0
+- lightning-bolts==0.5
+"""
+
+from __future__ import annotations
+
+import os
+import warnings
+from collections.abc import Callable
+from multiprocessing import cpu_count
+from typing import Literal, NewType
+
+import torchvision.datasets as tvd
+from lightning import Trainer
+
+# TODO: re-import stuff from pl_bolts's ImageNet datamodule.
+from pl_bolts.datasets import UnlabeledImagenet  # noqa
+
+# from mila_datamodules.clusters import CURRENT_CLUSTER
+# from mila_datamodules.vision.datasets import AdaptedDataset, ImageNet
+# from pl_bolts.datamodules.imagenet_datamodule import (
+#     ImagenetDataModule as _ImagenetDataModule,
+# )
+# from pl_bolts.datasets import UnlabeledImagenet
+# from pytorch_lightning import Trainer
+from torch import nn
+from torch.utils.data import DataLoader
+from torchvision.datasets import ImageNet
+
+from project.datamodules.image_classification.base import ImageClassificationDataModule
+from project.datamodules.image_classification.inaturalist import get_slurm_tmpdir
+
+Stage = Literal["fit", "validate", "test", "predict"]
+
+C = NewType("C", int)
+H = NewType("H", int)
+W = NewType("W", int)
+CURRENT_CLUSTER = os.environ.get("SLURM_CLUSTER")
+
+
+class ImagenetDataModule(ImageClassificationDataModule):
+    """Imagenet DataModule adapted to the Mila cluster.
+
+    - Copies/extracts the dataset to the `$SLURM_TMPDIR/data/ImageNet` directory.
+    - Uses the right number of workers, depending on the SLURM configuration.
+
+    TODO: Unclear why this couldn't just be a VisionDataModule using the ImageNet class.
+    TODO: Unclear why this `UnlabeledImagenet` dataset class is needed.
+    """
+
+    dataset_cls: type[tvd.ImageNet] = ImageNet
+
+    def __init__(
+        self,
+        data_dir: str | None = None,
+        meta_dir: str | None = None,
+        num_imgs_per_val_class: int = 50,
+        image_size: int = 224,
+        num_workers: int | None = None,
+        batch_size: int = 32,
+        shuffle: bool = True,
+        pin_memory: bool = True,
+        persistent_workers: bool = False,
+        drop_last: bool = False,
+        train_transforms: Callable | nn.Module | None = None,
+        val_transforms: Callable | nn.Module | None = None,
+        test_transforms: Callable | nn.Module | None = None,
+    ) -> None:
+        if os.environ.get("SLURM_CLUSTER") == "mila":
+            slurm_tmpdir = get_slurm_tmpdir()
+            assert slurm_tmpdir
+            fixed_data_dir = str(slurm_tmpdir / "data" / "ImageNet")
+            if data_dir is not None and data_dir != fixed_data_dir:
+                warnings.warn(
+                    RuntimeWarning(
+                        f"Ignoring passed data_dir ({data_dir}), using {fixed_data_dir} instead."
+                    )
+                )
+            data_dir = fixed_data_dir
+        else:
+            # Not on a SLURM cluster. 
`data_dir` must be provided (same as the base class). + # leave it as None, so the base class can raise the error. + pass + super().__init__( + data_dir=data_dir, # type: ignore + meta_dir=meta_dir, + num_imgs_per_val_class=num_imgs_per_val_class, + image_size=image_size, + num_workers=num_workers or num_cpus_to_use(), + batch_size=batch_size, + shuffle=shuffle, + pin_memory=pin_memory, + drop_last=drop_last, + ) + self.num_imgs_per_val_class = num_imgs_per_val_class + self.persistent_workers = persistent_workers + self._train_transforms = train_transforms + self._val_transforms = val_transforms + self._test_transforms = test_transforms + self.trainer: Trainer | None = None + + # NOTE: Do we want to store the dataset instances on the datamodule? Perhaps it could be + # useful for restoring state later or something? + # self.dataset_fit: UnlabeledImagenet | None = None + # self.dataset_validate: UnlabeledImagenet | None = None + # self.dataset_test: UnlabeledImagenet | None = None + # self.dataset_predict: UnlabeledImagenet | None = None + + def prepare_data(self) -> None: + """Prepares the data, copying the dataset to the SLURM temporary directory. + + NOTE: When using this datamodule without the PyTorch-Lightning Trainer, make sure to call + prepare_data() before calling train/val/test_dataloader(). + """ + if CURRENT_CLUSTER is not None: + # Create the dataset. + self.dataset_cls(split="train") + self.dataset_cls(split="val") + super().prepare_data() + + def train_dataloader(self) -> DataLoader: + # TODO: Use persistent_workers = True kwarg to DataLoader when num_workers > 0 and when + # in ddp_spawn mode. + from lightning.pytorch.strategies.ddp_spawn import DDPSpawnStrategy + + if self.trainer and isinstance(self.trainer.strategy, DDPSpawnStrategy): + # Use `persistent_workers=True` + # NOTE: Unfortunate that we have to copy all this code from the base class ;( + transforms = ( + self.train_transform() if self.train_transforms is None else self.train_transforms + ) + dataset = UnlabeledImagenet( + self.data_dir, + num_imgs_per_class=-1, + num_imgs_per_class_val_split=self.num_imgs_per_val_class, + meta_dir=self.meta_dir, + split="train", + transform=transforms, + ) + loader: DataLoader = DataLoader( + dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + drop_last=self.drop_last, + pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers, + ) + return loader + return super().train_dataloader() + + def val_dataloader(self) -> DataLoader: + return super().val_dataloader() + + def test_dataloader(self) -> DataLoader: + return super().test_dataloader() + + @property + def train_transforms(self) -> nn.Module | Callable | None: + return self._train_transforms + + @train_transforms.setter + def train_transforms(self, value: nn.Module | Callable | None): + self._train_transforms = value + + @property + def val_transforms(self) -> nn.Module | Callable | None: + return self._val_transforms + + @val_transforms.setter + def val_transforms(self, value: nn.Module | Callable | None): + self._val_transforms = value + + @property + def test_transforms(self) -> nn.Module | Callable | None: + return self._test_transforms + + @test_transforms.setter + def test_transforms(self, value: nn.Module | Callable | None): + self._test_transforms = value + + @property + def dims(self) -> tuple[C, H, W]: + """A tuple describing the shape of your data. Extra functionality exposed in ``size``. + + .. deprecated:: v1.5 Will be removed in v1.7.0. 
+        """
+        return self._dims

+    @dims.setter
+    def dims(self, v: tuple[C, H, W]):
+        self._dims = v
+
+
+def num_cpus_to_use() -> int:
+    if "SLURM_CPUS_PER_TASK" in os.environ:
+        return int(os.environ["SLURM_CPUS_PER_TASK"])
+    if "SLURM_CPUS_ON_NODE" in os.environ:
+        return int(os.environ["SLURM_CPUS_ON_NODE"])
+    return cpu_count()

From 3359378c73e83882420b7e6a2ed0a1aeb4cb59c6 Mon Sep 17 00:00:00 2001
From: Fabrice Normandin
Date: Mon, 17 Jun 2024 15:00:43 +0000
Subject: [PATCH 02/27] Add SLURM_TMPDIR in devcontainer, add notes

Signed-off-by: Fabrice Normandin
---
 .devcontainer/devcontainer.json               |  29 ++-
 project/datamodules/__init__.py               |   2 +
 .../image_classification/imagenet.py          | 206 ------------------
 3 files changed, 24 insertions(+), 213 deletions(-)
 delete mode 100644 project/datamodules/image_classification/imagenet.py

diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
index 27ff0fbf..25923af4 100644
--- a/.devcontainer/devcontainer.json
+++ b/.devcontainer/devcontainer.json
@@ -59,15 +59,27 @@
         }
     },
     "containerEnv": {
-        "SCRATCH": "/home/vscode/scratch"
+        "SCRATCH": "/home/vscode/scratch",
+        "SLURM_TMPDIR": "/tmp"
     },
-    // Mount a "$SCRATCH" directory in the host to ~/scratch in the container.
-    // Mount /network to use this to mount a "$SCRATCH" directory in the host to ~/scratch in the container.
     "mounts": [
         // https://code.visualstudio.com/remote/advancedcontainers/add-local-file-mount
-        "source=${localEnv:HOME}/.cache/pdm,target=/home/vscode/.pdm_install_cache,type=bind,consistency=cached",
-        "source=${localEnv:SCRATCH},target=/home/vscode/scratch,type=bind,consistency=cached",
-        "source=${localEnv:NETWORK_DIR:/network},target=/network,type=bind,readonly"
+        // Mount a directory which will contain the pdm installation cache (shared with the host machine).
+        // This will use $SCRATCH/.cache/pdm if $SCRATCH is set on the host, otherwise ~/.cache/pdm.
+        "source=${localEnv:SCRATCH:~}/.cache/pdm,target=/home/vscode/.pdm_install_cache,type=bind,consistency=cached",
+        // Mount a "$SCRATCH" directory in the host to ~/scratch in the container.
+        // FIXME: This assumes that either the SCRATCH environment variable is set on the host, or
+        // that the $HOME/scratch directory exists.
+        "source=${localEnv:SCRATCH:~/scratch},target=/home/vscode/scratch,type=bind,consistency=cached",
+        // Mount /network to match the /network directory on the host.
+        // FIXME: This assumes that either the NETWORK_DIR environment variable is set on the host, or
+        // that the /network directory exists.
+        "source=${localEnv:NETWORK_DIR:/network},target=/network,type=bind,readonly",
+        // Mount $SLURM_TMPDIR on the host machine (or /tmp/slurm_tmpdir if it isn't set) to /tmp in the container.
+        // note: there's also a SLURM_TMPDIR env variable set to /tmp in the container (see "containerEnv" above).
+        // NOTE: this assumes that either $SLURM_TMPDIR is set on the host machine (e.g. a compute node)
+        // or that `/tmp/slurm_tmpdir` exists on the host machine.
+        "source=${localEnv:SLURM_TMPDIR:/tmp/slurm_tmpdir},target=/tmp,type=bind,consistency=cached"
     ],
     "runArgs": [
         "--gpus",
@@ -76,7 +88,10 @@
         "all"
    ],
    // create the pdm cache dir on the host machine if it doesn't exist yet so the mount above
    // doesn't fail.
-    "initializeCommand": "mkdir -p ~/.cache/pdm",
+    "initializeCommand": {
+        "create pdm install cache": "mkdir -p ~/.cache/pdm", // todo: put this on $SCRATCH on the host (e.g. 
compute node) + "create fake SLURM_TMPDIR": "mkdir -p /tmp/slurm_tmpdir" // this is fine on compute nodes + }, // NOTE: Getting some permission issues with the .cache dir if mounting .cache/pdm to // .cache/pdm in the container. Therefore, here I'm making a symlink from ~/.cache/pdm to // ~/.pdm_install_cache so the ~/.cache directory is writeable by the container. diff --git a/project/datamodules/__init__.py b/project/datamodules/__init__.py index 5e298407..ed4bb7b0 100644 --- a/project/datamodules/__init__.py +++ b/project/datamodules/__init__.py @@ -4,6 +4,7 @@ from .image_classification.imagenet32 import ImageNet32DataModule, imagenet32_normalization from .image_classification.mnist import MNISTDataModule from .vision.base import VisionDataModule +from .vision.imagenet import ImageNetDataModule __all__ = [ "cifar10_normalization", @@ -12,6 +13,7 @@ "ImageClassificationDataModule", "imagenet32_normalization", "ImageNet32DataModule", + "ImageNetDataModule", "MNISTDataModule", "VisionDataModule", ] diff --git a/project/datamodules/image_classification/imagenet.py b/project/datamodules/image_classification/imagenet.py deleted file mode 100644 index a42fca04..00000000 --- a/project/datamodules/image_classification/imagenet.py +++ /dev/null @@ -1,206 +0,0 @@ -"""ImageNet datamodule adapted to the Mila cluster. - -Can be used either with a PyTorch-Lightning Trainer, or by itself to easily get efficient -dataloaders for the ImageNet dataset. - -Requirements (these are the versions I'm using, but this can probably be loosened a bit). -- pytorch-lightning==1.6.0 -- lightning-bolts==0.5 -""" - -from __future__ import annotations - -import os -import warnings -from collections.abc import Callable -from multiprocessing import cpu_count -from typing import Literal, NewType - -import torchvision.datasets as tvd -from lightning import Trainer - -# TODO: reimprot stuff from pl_bolts's ImageNet datamodule. -from pl_bolts.datasets import UnlabeledImagenet # noqa - -# from mila_datamodules.clusters import CURRENT_CLUSTER -# from mila_datamodules.vision.datasets import AdaptedDataset, ImageNet -# from pl_bolts.datamodules.imagenet_datamodule import ( -# ImagenetDataModule as _ImagenetDataModule, -# ) -# from pl_bolts.datasets import UnlabeledImagenet -# from pytorch_lightning import Trainer -from torch import nn -from torch.utils.data import DataLoader -from torchvision.datasets import ImageNet - -from project.datamodules.image_classification.base import ImageClassificationDataModule -from project.datamodules.image_classification.inaturalist import get_slurm_tmpdir - -Stage = Literal["fit", "validate", "test", "predict"] - -C = NewType("C", int) -H = NewType("H", int) -W = NewType("W", int) -CURRENT_CLUSTER = os.environ.get("SLURM_CLUSTER") - - -class ImagenetDataModule(ImageClassificationDataModule): - """Imagenet DataModule adapted to the Mila cluster. - - - Copies/Extracts the datasets to the `$SLURM_TMPDIR/data/ImageNet` directory. - - Uses the right number of workers, depending on the SLURM configuration. - - TODO: Unclear why this couldn't just be a VisionDataModule using the ImageNet class. - TODO: Unclear why this `UnlabeledImageNet` dataset class is needed. 
- """ - - dataset_cls: type[tvd.ImageNet] = ImageNet - - def __init__( - self, - data_dir: str | None = None, - meta_dir: str | None = None, - num_imgs_per_val_class: int = 50, - image_size: int = 224, - num_workers: int | None = None, - batch_size: int = 32, - shuffle: bool = True, - pin_memory: bool = True, - persistent_workers: bool = False, - drop_last: bool = False, - train_transforms: Callable | nn.Module | None = None, - val_transforms: Callable | nn.Module | None = None, - test_transforms: Callable | nn.Module | None = None, - ) -> None: - if os.environ.get("SLURM_CLUSTER") == "mila": - slurm_tmpdir = get_slurm_tmpdir() - assert slurm_tmpdir - fixed_data_dir = str(slurm_tmpdir / "data" / "ImageNet") - if data_dir is not None and data_dir != fixed_data_dir: - warnings.warn( - RuntimeWarning( - f"Ignoring passed data_dir ({data_dir}), using {fixed_data_dir} instead." - ) - ) - data_dir = fixed_data_dir - else: - # Not on a SLURM cluster. `data_dir` must be provided (same as the base class). - # leave it as None, so the base class can raise the error. - pass - super().__init__( - data_dir=data_dir, # type: ignore - meta_dir=meta_dir, - num_imgs_per_val_class=num_imgs_per_val_class, - image_size=image_size, - num_workers=num_workers or num_cpus_to_use(), - batch_size=batch_size, - shuffle=shuffle, - pin_memory=pin_memory, - drop_last=drop_last, - ) - self.num_imgs_per_val_class = num_imgs_per_val_class - self.persistent_workers = persistent_workers - self._train_transforms = train_transforms - self._val_transforms = val_transforms - self._test_transforms = test_transforms - self.trainer: Trainer | None = None - - # NOTE: Do we want to store the dataset instances on the datamodule? Perhaps it could be - # useful for restoring state later or something? - # self.dataset_fit: UnlabeledImagenet | None = None - # self.dataset_validate: UnlabeledImagenet | None = None - # self.dataset_test: UnlabeledImagenet | None = None - # self.dataset_predict: UnlabeledImagenet | None = None - - def prepare_data(self) -> None: - """Prepares the data, copying the dataset to the SLURM temporary directory. - - NOTE: When using this datamodule without the PyTorch-Lightning Trainer, make sure to call - prepare_data() before calling train/val/test_dataloader(). - """ - if CURRENT_CLUSTER is not None: - # Create the dataset. - self.dataset_cls(split="train") - self.dataset_cls(split="val") - super().prepare_data() - - def train_dataloader(self) -> DataLoader: - # TODO: Use persistent_workers = True kwarg to DataLoader when num_workers > 0 and when - # in ddp_spawn mode. 
- from lightning.pytorch.strategies.ddp_spawn import DDPSpawnStrategy - - if self.trainer and isinstance(self.trainer.strategy, DDPSpawnStrategy): - # Use `persistent_workers=True` - # NOTE: Unfortunate that we have to copy all this code from the base class ;( - transforms = ( - self.train_transform() if self.train_transforms is None else self.train_transforms - ) - dataset = UnlabeledImagenet( - self.data_dir, - num_imgs_per_class=-1, - num_imgs_per_class_val_split=self.num_imgs_per_val_class, - meta_dir=self.meta_dir, - split="train", - transform=transforms, - ) - loader: DataLoader = DataLoader( - dataset, - batch_size=self.batch_size, - shuffle=self.shuffle, - num_workers=self.num_workers, - drop_last=self.drop_last, - pin_memory=self.pin_memory, - persistent_workers=self.persistent_workers, - ) - return loader - return super().train_dataloader() - - def val_dataloader(self) -> DataLoader: - return super().val_dataloader() - - def test_dataloader(self) -> DataLoader: - return super().test_dataloader() - - @property - def train_transforms(self) -> nn.Module | Callable | None: - return self._train_transforms - - @train_transforms.setter - def train_transforms(self, value: nn.Module | Callable | None): - self._train_transforms = value - - @property - def val_transforms(self) -> nn.Module | Callable | None: - return self._val_transforms - - @val_transforms.setter - def val_transforms(self, value: nn.Module | Callable | None): - self._val_transforms = value - - @property - def test_transforms(self) -> nn.Module | Callable | None: - return self._test_transforms - - @test_transforms.setter - def test_transforms(self, value: nn.Module | Callable | None): - self._test_transforms = value - - @property - def dims(self) -> tuple[C, H, W]: - """A tuple describing the shape of your data. Extra functionality exposed in ``size``. - - .. deprecated:: v1.5 Will be removed in v1.7.0. 
- """ - return self._dims - - @dims.setter - def dims(self, v: tuple[C, H, W]): - self._dims = v - - -def num_cpus_to_use() -> int: - if "SLURM_CPUS_PER_TASK" in os.environ: - return int(os.environ["SLURM_CPUS_PER_TASK"]) - if "SLURM_CPUS_ON_NODE" in os.environ: - return int(os.environ["SLURM_CPUS_ON_NODE"]) - return cpu_count() From 9956f93c87d6addc8edd77253ef3bf60a1aefcba Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Mon, 17 Jun 2024 15:06:11 +0000 Subject: [PATCH 03/27] Add Olexa's imagenet recipe in python form Signed-off-by: Fabrice Normandin --- project/datamodules/vision/imagenet.py | 173 +++++++++++++++++++++++++ 1 file changed, 173 insertions(+) create mode 100644 project/datamodules/vision/imagenet.py diff --git a/project/datamodules/vision/imagenet.py b/project/datamodules/vision/imagenet.py new file mode 100644 index 00000000..2ea7f423 --- /dev/null +++ b/project/datamodules/vision/imagenet.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +import contextlib +import os +import shutil +import tarfile +import time +from collections.abc import Callable +from logging import getLogger as get_logger +from pathlib import Path +from typing import ClassVar, Concatenate, Literal, TypeVar + +import torch +import tqdm +from torchvision.datasets import ImageNet + +from project.configs.datamodule import DATA_DIR +from project.datamodules.vision.base import VisionDataModule +from project.utils.types import C, H, StageStr, W +from project.utils.types.protocols import Module + +logger = get_logger(__name__) +ImageNetType = TypeVar("ImageNetType", bound=ImageNet) + + +@contextlib.contextmanager +def change_directory(path: Path): + curdir = Path.cwd() + os.chdir(path) + yield + os.chdir(curdir) + + +class ImageNetDataModule(VisionDataModule): + name: ClassVar[str] = "imagenet" + """Dataset name.""" + + dataset_cls: ClassVar[type[ImageNet]] = ImageNet + """Dataset class to use.""" + + dims: tuple[C, H, W] = (C(3), H(224), W(224)) + """A tuple describing the shape of the data.""" + + num_classes: ClassVar[int] = 1000 + + def __init__( + self, + root: str | Path = DATA_DIR, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.root = Path(root) + + def prepare_data(self) -> None: + network_imagenet_dir = Path("/network/datasets/imagenet") + assert network_imagenet_dir.exists() + prepare_imagenet( + self.root, network_imagenet_dir=network_imagenet_dir, split="train", **self.EXTRA_ARGS + ) + prepare_imagenet( + self.root, network_imagenet_dir=network_imagenet_dir, split="val", **self.EXTRA_ARGS + ) + + def setup(self, stage: StageStr | None = None) -> None: + super().setup(stage) + + def default_transforms(self) -> Module[[torch.Tensor], torch.Tensor]: + from torchvision.models.resnet import ResNet152_Weights + + return ResNet152_Weights.IMAGENET1K_V1.transforms + + +def prepare_imagenet[**P]( + root: str | Path, + split: Literal["train", "val"] = "train", + network_imagenet_dir: Path = Path("/network/datasets/imagenet"), + _dataset: Callable[Concatenate[str, Literal["train", "val"], P], ImageNet] = ImageNet, + *args: P.args, + **kwargs: P.kwargs, +) -> ImageNet: + """Custom preparation function for ImageNet, using @obilaniu's tar magic in Python form. 
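+
+    Example use on a compute node (a sketch; assumes `$SLURM_TMPDIR` is set and that the
+    archives are available under `/network/datasets/imagenet`):
+
+    ```python
+    import os
+    from pathlib import Path
+
+    dataset = prepare_imagenet(Path(os.environ["SLURM_TMPDIR"]) / "imagenet", split="val")
+    ```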
+ + The core of this is equivalent to these bash commands: + + ```bash + mkdir -p $SLURM_TMPDIR/imagenet/val + cd $SLURM_TMPDIR/imagenet/val + tar -xf /network/scratch/b/bilaniuo/ILSVRC2012_img_val.tar + mkdir -p $SLURM_TMPDIR/imagenet/train + cd $SLURM_TMPDIR/imagenet/train + tar -xf /network/datasets/imagenet/ILSVRC2012_img_train.tar \ + --to-command='mkdir ${TAR_REALNAME%.tar}; tar -xC ${TAR_REALNAME%.tar}' + ``` + """ + root = Path(root) + if not network_imagenet_dir.exists(): + raise NotImplementedError( + f"Assuming that we're running on the Mila cluster where {network_imagenet_dir} exists for now." + ) + val_archive_file_name = "ILSVRC2012_img_val.tar" + train_archive_file_name = "ILSVRC2012_img_train.tar" + devkit_file_name = "ILSVRC2012_devkit_t12.tar.gz" + md5sums_file_name = "md5sums" + + def _symlink_if_needed(filename: str, network_imagenet_dir: Path): + symlink = root / filename + if not symlink.exists(): + symlink.symlink_to(network_imagenet_dir / filename) + + # Create a symlink to the archive in $SLURM_TMPDIR, because torchvision expects it to be + # there. + _symlink_if_needed(train_archive_file_name, network_imagenet_dir) + _symlink_if_needed(val_archive_file_name, network_imagenet_dir) + _symlink_if_needed(devkit_file_name, network_imagenet_dir) + # TODO: COPY the file, not symlink it! (otherwise we get some "Read-only filesystem" errors + # when calling tvd.ImageNet(...). (Probably because the constructor tries to open the file) + # _symlink_if_needed(md5sums_file_name, network_imagenet_dir) + md5sums_file = root / md5sums_file_name + if not md5sums_file.exists(): + shutil.copyfile(network_imagenet_dir / md5sums_file_name, md5sums_file) + md5sums_file.chmod(0o755) + + logger.info("Extracting the ImageNet archives using Olexa's tar magic in python form...") + + if split == "train": + train_dir = root / "train" + train_dir.mkdir(exist_ok=True, parents=True) + + # The ImageNet train archive is a tarfile of tarfiles (one for each class). + with tarfile.open(network_imagenet_dir / train_archive_file_name) as train_tarfile: + for member in tqdm.tqdm( + train_tarfile, + total=1000, # hard-coded here, since we know there are 1000 folders. + desc="Extracting train archive", + unit="Directories", + position=0, + ): + buffer = train_tarfile.extractfile(member) + assert buffer is not None + subdir = train_dir / member.name.replace(".tar", "") + subdir.mkdir(mode=0o755, parents=True, exist_ok=True) + files_in_subdir = set(p.name for p in subdir.iterdir()) + with tarfile.open(fileobj=buffer, mode="r|*") as sub_tarfile: + for tarinfo in sub_tarfile: + if tarinfo.name in files_in_subdir: + # Image file is already in the directory. 
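+                            # (Skipping here makes the extraction resumable: a re-run after an
+                            # interruption only extracts the image files that are missing.)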
+ continue + sub_tarfile.extract(tarinfo, subdir) + + else: + val_dir = root / "val" + val_dir.mkdir(exist_ok=True, parents=True) + with tarfile.open(network_imagenet_dir / val_archive_file_name) as val_tarfile: + val_tarfile.extractall(val_dir) + + return _dataset(str(root), split, *args, **kwargs) + + +def main(): + slurm_tmpdir = Path(os.environ["SLURM_TMPDIR"]) + datamodule = ImageNetDataModule(slurm_tmpdir) + start = time.time() + datamodule.prepare_data() + datamodule.setup("fit") + dl = datamodule.train_dataloader() + _batch = next(iter(dl)) + end = time.time() + print(f"Prepared imagenet in {end-start:.2f}s.") + + +if __name__ == "__main__": + main() From c135864850bf1820dea311455bed581a86f997e9 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Mon, 17 Jun 2024 18:34:27 +0000 Subject: [PATCH 04/27] Add DataModule for the ImageNet dataset Signed-off-by: Fabrice Normandin --- project/configs/datamodule/__init__.py | 5 +- project/datamodules/__init__.py | 2 +- .../image_classification/imagenet.py | 422 ++++++++++++++++++ project/datamodules/vision/base.py | 88 ++-- project/datamodules/vision/imagenet.py | 173 ------- 5 files changed, 452 insertions(+), 238 deletions(-) create mode 100644 project/datamodules/image_classification/imagenet.py delete mode 100644 project/datamodules/vision/imagenet.py diff --git a/project/configs/datamodule/__init__.py b/project/configs/datamodule/__init__.py index 03154b43..541a35e1 100644 --- a/project/configs/datamodule/__init__.py +++ b/project/configs/datamodule/__init__.py @@ -23,6 +23,7 @@ Version, ) from project.datamodules.image_classification.mnist import mnist_train_transforms +from project.datamodules.vision.base import SLURM_TMPDIR FILE = Path(__file__) REPO_ROOTDIR = FILE.parent @@ -31,10 +32,6 @@ break REPO_ROOTDIR = REPO_ROOTDIR.parent - -SLURM_TMPDIR: Path | None = ( - Path(os.environ["SLURM_TMPDIR"]) if "SLURM_TMPDIR" in os.environ else None -) SLURM_JOB_ID: int | None = ( int(os.environ["SLURM_JOB_ID"]) if "SLURM_JOB_ID" in os.environ else None ) diff --git a/project/datamodules/__init__.py b/project/datamodules/__init__.py index ed4bb7b0..52e8d92c 100644 --- a/project/datamodules/__init__.py +++ b/project/datamodules/__init__.py @@ -1,10 +1,10 @@ from .image_classification import ImageClassificationDataModule from .image_classification.cifar10 import CIFAR10DataModule, cifar10_normalization from .image_classification.fashion_mnist import FashionMNISTDataModule +from .image_classification.imagenet import ImageNetDataModule from .image_classification.imagenet32 import ImageNet32DataModule, imagenet32_normalization from .image_classification.mnist import MNISTDataModule from .vision.base import VisionDataModule -from .vision.imagenet import ImageNetDataModule __all__ = [ "cifar10_normalization", diff --git a/project/datamodules/image_classification/imagenet.py b/project/datamodules/image_classification/imagenet.py new file mode 100644 index 00000000..c4297495 --- /dev/null +++ b/project/datamodules/image_classification/imagenet.py @@ -0,0 +1,422 @@ +from __future__ import annotations + +import logging +import math +import os +import shutil +import tarfile +import time +from collections import defaultdict +from collections.abc import Callable +from logging import getLogger as get_logger +from pathlib import Path +from typing import ClassVar, Literal + +import rich +import rich.logging +import torch +import torch.utils.data +import tqdm +from torchvision.datasets import ImageNet +from torchvision.models.resnet import ResNet152_Weights 
+from torchvision.transforms import v2 as transform_lib + +from project.datamodules.vision.base import VisionDataModule +from project.utils.types import C, H, StageStr, W +from project.utils.types.protocols import Module + +logger = get_logger(__name__) + + +def imagenet_normalization(): + return transform_lib.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + +type ClassIndex = int +type ImageIndex = int + + +class ImageNetDataModule(VisionDataModule): + """ImageNet datamodule. + + Extracted from https://github.com/Lightning-Universe/lightning-bolts/blob/master/src/pl_bolts/datamodules/imagenet_datamodule.py + - Made this a subclass of VisionDataModule + + Notes: + - train_dataloader uses the train split of imagenet2012 and puts away a portion of it for the validation split. + - val_dataloader uses the part of the train split of imagenet2012 that was not used for training via + `num_imgs_per_val_class` + - TODO: needs to pass split='val' to UnlabeledImagenet. + - test_dataloader uses the validation split of imagenet2012 for testing. + - TODO: need to pass num_imgs_per_class=-1 for test dataset and split="test". + """ + + name: ClassVar[str] = "imagenet" + """Dataset name.""" + + dataset_cls: ClassVar[type[ImageNet]] = ImageNet + """Dataset class to use.""" + + dims: tuple[C, H, W] = (C(3), H(224), W(224)) + """A tuple describing the shape of the data.""" + + num_classes: ClassVar[int] = 1000 + + def __init__( + self, + data_dir: str | Path | None = None, + *, + val_split: int + | float = 0.01, # save `val_split`% of the training data *of each class* for validation. + num_workers: int | None = None, + normalize: bool = False, + image_size: int = 224, + batch_size: int = 32, + seed: int = 42, + shuffle: bool = True, + pin_memory: bool = True, + drop_last: bool = False, + train_transforms: Callable | None = None, + val_transforms: Callable | None = None, + test_transforms: Callable | None = None, + **kwargs, + ): + """Creates an ImageNet datamodule (doesn't load or prepare the dataset yet). + + Parameters + ---------- + data_dir: path to the imagenet dataset file + val_split: save `val_split`% of the training data *of each class* for validation. 
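(e.g. with the default
+            val_split of 0.01, the datamodule holds out math.ceil(0.01 * 1300) == 13 of the
+            ~1300 training images of each class)
+        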
+ image_size: final image size + num_workers: how many data workers + batch_size: batch_size + shuffle: If true shuffles the data every epoch + pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before \ + returning them + drop_last: If true drops the last incomplete batch + """ + self.image_size = image_size + super().__init__( + data_dir, + num_workers=num_workers, + val_split=val_split, + shuffle=shuffle, + pin_memory=pin_memory, + normalize=normalize, + seed=seed, + batch_size=batch_size, + drop_last=drop_last, + train_transforms=train_transforms or self.train_transform(), + val_transforms=val_transforms or self.val_transform(), + test_transforms=test_transforms, + **kwargs, + ) + self.dims = (C(3), H(self.image_size), W(self.image_size)) + self.train_kwargs = self.train_kwargs | {"split": "train"} + self.valid_kwargs = self.valid_kwargs | {"split": "train"} + self.test_kwargs = self.test_kwargs | {"split": "val"} + # self.test_dataset_cls = UnlabeledImagenet + + def prepare_data(self) -> None: + network_imagenet_dir = Path("/network/datasets/imagenet") + logger.debug(f"Preparing ImageNet train split in {self.data_dir}...") + prepare_imagenet( + self.data_dir, + network_imagenet_dir=network_imagenet_dir, + split="train", + ) + logger.debug(f"Preparing ImageNet val (test) split in {self.data_dir}...") + prepare_imagenet( + self.data_dir, + network_imagenet_dir=network_imagenet_dir, + split="val", + ) + + super().prepare_data() + + def setup(self, stage: StageStr | None = None) -> None: + logger.debug(f"Setup ImageNet datamodule for {stage=}") + super().setup(stage) + + def _split_dataset(self, dataset: ImageNet, train: bool = True) -> torch.utils.data.Dataset: + class_item_indices: dict[ClassIndex, list[ImageIndex]] = defaultdict(list) + for dataset_index, y in enumerate(dataset.targets): + class_item_indices[y].append(dataset_index) + + train_val_split_seed = self.seed + gen = torch.Generator().manual_seed(train_val_split_seed) + + train_class_indices: dict[ClassIndex, list[ImageIndex]] = {} + valid_class_indices: dict[ClassIndex, list[ImageIndex]] = {} + + for label, dataset_indices in class_item_indices.items(): + num_images_in_class = len(dataset_indices) + num_valid = math.ceil(self.val_split * num_images_in_class) + num_train = num_images_in_class - num_valid + + permutation = torch.randperm(len(dataset_indices), generator=gen) + dataset_indices = torch.tensor(dataset_indices)[permutation].tolist() + + train_indices = dataset_indices[:num_train] + valid_indices = dataset_indices[num_train:] + + train_class_indices[label] = train_indices + valid_class_indices[label] = valid_indices + + all_train_indices = sum(train_class_indices.values(), []) + all_valid_indices = sum(valid_class_indices.values(), []) + train_dataset = torch.utils.data.Subset(dataset, all_train_indices) + valid_dataset = torch.utils.data.Subset(dataset, all_valid_indices) + if train: + return train_dataset + return valid_dataset + + def _verify_splits(self, data_dir: str | Path, split: str) -> None: + dirs = os.listdir(data_dir) + if split not in dirs: + raise FileNotFoundError( + f"a {split} Imagenet split was not found in {data_dir}," + f" make sure the folder contains a subfolder named {split}" + ) + + def default_transforms(self) -> Module[[torch.Tensor], torch.Tensor]: + return ResNet152_Weights.IMAGENET1K_V1.transforms + + def train_transform(self) -> Module[[torch.Tensor], torch.Tensor]: + """The standard imagenet transforms. + + .. 
code-block:: python + + transform_lib.Compose([ + transform_lib.RandomResizedCrop(self.image_size), + transform_lib.RandomHorizontalFlip(), + transform_lib.ToTensor(), + transform_lib.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225] + ), + ]) + """ + return transform_lib.Compose( + [ + transform_lib.RandomResizedCrop(self.image_size), + transform_lib.RandomHorizontalFlip(), + transform_lib.ToTensor(), + imagenet_normalization(), + ] + ) + + def val_transform(self) -> Callable: + """The standard imagenet transforms for validation. + + .. code-block:: python + + transform_lib.Compose([ + transform_lib.Resize(self.image_size + 32), + transform_lib.CenterCrop(self.image_size), + transform_lib.ToTensor(), + transform_lib.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225] + ), + ]) + """ + + return transform_lib.Compose( + [ + transform_lib.Resize(self.image_size + 32), + transform_lib.CenterCrop(self.image_size), + transform_lib.ToTensor(), + imagenet_normalization(), + ] + ) + + +def prepare_imagenet( + root: Path, + split: Literal["train", "val"] = "train", + network_imagenet_dir: Path = Path("/network/datasets/imagenet"), +) -> None: + """Custom preparation function for ImageNet, using @obilaniu's tar magic in Python form. + + The core of this is equivalent to these bash commands: + + ```bash + mkdir -p $SLURM_TMPDIR/imagenet/val + cd $SLURM_TMPDIR/imagenet/val + tar -xf /network/scratch/b/bilaniuo/ILSVRC2012_img_val.tar + mkdir -p $SLURM_TMPDIR/imagenet/train + cd $SLURM_TMPDIR/imagenet/train + tar -xf /network/datasets/imagenet/ILSVRC2012_img_train.tar \ + --to-command='mkdir ${TAR_REALNAME%.tar}; tar -xC ${TAR_REALNAME%.tar}' + ``` + """ + if not network_imagenet_dir.exists(): + raise NotImplementedError( + f"Assuming that we're running on a cluster where {network_imagenet_dir} exists for now." + ) + val_archive_file_name = "ILSVRC2012_img_val.tar" + train_archive_file_name = "ILSVRC2012_img_train.tar" + devkit_file_name = "ILSVRC2012_devkit_t12.tar.gz" + md5sums_file_name = "md5sums" + + def _symlink_if_needed(filename: str, network_imagenet_dir: Path): + if not (symlink := root / filename).exists(): + symlink.symlink_to(network_imagenet_dir / filename) + + # Create a symlink to the archive in $SLURM_TMPDIR, because torchvision expects it to be + # there. + _symlink_if_needed(train_archive_file_name, network_imagenet_dir) + _symlink_if_needed(val_archive_file_name, network_imagenet_dir) + _symlink_if_needed(devkit_file_name, network_imagenet_dir) + # TODO: COPY the file, not symlink it! (otherwise we get some "Read-only filesystem" errors + # when calling tvd.ImageNet(...). (Probably because the constructor tries to open the file) + # _symlink_if_needed(md5sums_file_name, network_imagenet_dir) + md5sums_file = root / md5sums_file_name + if not md5sums_file.exists(): + shutil.copyfile(network_imagenet_dir / md5sums_file_name, md5sums_file) + md5sums_file.chmod(0o755) + + if split == "train": + train_dir = root / "train" + train_dir.mkdir(exist_ok=True, parents=True) + train_archive = network_imagenet_dir / train_archive_file_name + _extract_train_archive( + train_archive=train_archive, + train_dir=train_dir, + previously_extracted_dirs_file=root / "previously_extracted_dirs.txt", + ) + + # OR: could just reuse the equivalent-ish from torchvision, but which doesn't support + # resuming after an interrupt. 
+ # from torchvision.datasets.imagenet import parse_train_archive + # parse_train_archive(root, file=train_archive_file_name, folder="train") + else: + from torchvision.datasets.imagenet import ( + load_meta_file, + parse_devkit_archive, + parse_val_archive, + ) + + parse_devkit_archive(root, file=devkit_file_name) + wnids = load_meta_file(root)[1] + val_dir = root / "val" + if not val_dir.exists(): + logger.debug(f"Extracting ImageNet test set to {val_dir}") + parse_val_archive(root, file=val_archive_file_name, wnids=wnids) + return + + logger.debug(f"listing the contents of {val_dir}") + children = list(val_dir.iterdir()) + + if not children: + logger.debug(f"Extracting ImageNet test set to {val_dir}") + parse_val_archive(root, file=val_archive_file_name, wnids=wnids) + return + + if all(child.is_dir() for child in children): + logger.info("Validation split already extracted. Skipping.") + return + + logger.warning( + f"Incomplete extraction of the ImageNet test set in {val_dir}, deleting it and extracting again." + ) + shutil.rmtree(root / "val", ignore_errors=False) + parse_val_archive(root, file=val_archive_file_name, wnids=wnids) + + # val_dir = root / "val" + # val_dir.mkdir(exist_ok=True, parents=True) + # with tarfile.open(network_imagenet_dir / val_archive_file_name) as val_tarfile: + # val_tarfile.extractall(val_dir) + + +def _extract_train_archive( + *, train_archive: Path, train_dir: Path, previously_extracted_dirs_file: Path +) -> None: + # The ImageNet train archive is a tarfile of tarfiles (one for each class). + logger.debug("Extracting the ImageNet train archive using Olexa's tar magic in python form...") + + # Save a small text file or something that tells us which subdirs are + # done extracting so we can just skip ahead to the right directory? + previously_extracted_dirs: set[str] = set() + if previously_extracted_dirs_file.exists(): + previously_extracted_dirs = set( + stripped_line + for line in previously_extracted_dirs_file.read_text().splitlines() + if (stripped_line := line.strip()) + ) + logger.debug( + f"{len(previously_extracted_dirs)} directories have already been fully extracted." + ) + previously_extracted_dirs_file.write_text( + "\n".join(sorted(previously_extracted_dirs)) + "\n" + ) + + if len(previously_extracted_dirs) == 1000: + logger.info("Train archive already fully extracted. Skipping.") + return + + with tarfile.open(train_archive, mode="r") as train_tarfile: + for class_id, member in enumerate( + tqdm.tqdm( + train_tarfile, + total=1000, # hard-coded here, since we know there are 1000 folders. + desc="Extracting train archive", + unit="Directories", + position=0, + ) + ): + if member.name in previously_extracted_dirs: + continue + + buffer = train_tarfile.extractfile(member) + assert buffer is not None + + class_subdir = train_dir / member.name.replace(".tar", "") + class_subdir_existed = class_subdir.exists() + if class_subdir_existed: + files_in_subdir = set(p.name for p in class_subdir.iterdir()) + else: + class_subdir.mkdir(parents=True, exist_ok=True) + files_in_subdir = set() + + with tarfile.open(fileobj=buffer, mode="r|*") as sub_tarfile: + for tarinfo in sub_tarfile: + image_file_path = class_subdir / tarinfo.name + if files_in_subdir and image_file_path.name in files_in_subdir: + # Image file is already in the directory. + continue + sub_tarfile.extract(tarinfo, class_subdir) + + # Alternative: .extractall with a list of members to extract: + # members = sub_tarfile.getmembers() # note: loads the full archive. 
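# (the streaming mode "r|*" used above avoids that upfront full-archive scan)
+            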
+ # if not files_in_subdir: + # members_to_extract = members + # else: + # members_to_extract = [m for m in members if m.name not in files_in_subdir] + # if members_to_extract: + # sub_tarfile.extractall(subdir, members=members_to_extract, filter="data") + + assert member.name not in previously_extracted_dirs + previously_extracted_dirs.add(member.name) + with previously_extracted_dirs_file.open("a") as f: + f.write(f"{member.name}\n") + + +def main(): + logging.basicConfig( + level=logging.DEBUG, format="%(message)s", handlers=[rich.logging.RichHandler()] + ) + slurm_tmpdir = Path(os.environ["SLURM_TMPDIR"]) + datamodule = ImageNetDataModule(slurm_tmpdir) + start = time.time() + datamodule.prepare_data() + datamodule.setup("fit") + dl = datamodule.train_dataloader() + _batch = next(iter(dl)) + end = time.time() + print(f"Prepared imagenet in {end-start:.2f}s.") + + +if __name__ == "__main__": + main() diff --git a/project/datamodules/vision/base.py b/project/datamodules/vision/base.py index 9b54144e..9ec9bc22 100644 --- a/project/datamodules/vision/base.py +++ b/project/datamodules/vision/base.py @@ -6,7 +6,7 @@ from collections.abc import Callable from logging import getLogger as get_logger from pathlib import Path -from typing import Any, ClassVar, Concatenate +from typing import ClassVar, Concatenate import torch from lightning import LightningDataModule @@ -81,7 +81,7 @@ def __init__( super().__init__() from project.configs.datamodule import DATA_DIR - self.data_dir = data_dir if data_dir is not None else DATA_DIR + self.data_dir: Path = Path(data_dir or DATA_DIR) self.val_split = val_split if num_workers is None: num_workers = num_cpus_on_node() @@ -93,15 +93,23 @@ def __init__( self.shuffle = shuffle self.pin_memory = pin_memory self.drop_last = drop_last - self._train_transforms = train_transforms - self._val_transforms = val_transforms - self._test_transforms = test_transforms + self.train_transforms = train_transforms + self.val_transforms = val_transforms + self.test_transforms = test_transforms self.EXTRA_ARGS = kwargs - self.train_kwargs = self.EXTRA_ARGS.copy() - self.test_kwargs = self.EXTRA_ARGS.copy() + self.train_kwargs = self.EXTRA_ARGS | { + "transform": self.train_transforms or self.default_transforms() + } + self.valid_kwargs = self.EXTRA_ARGS | { + "transform": self.val_transforms or self.default_transforms() + } + self.test_kwargs = self.EXTRA_ARGS | { + "transform": self.test_transforms or self.default_transforms() + } if _has_constructor_argument(self.dataset_cls, "train"): self.train_kwargs["train"] = True + self.valid_kwargs["train"] = True self.test_kwargs["train"] = False _rng = torch.Generator(device="cpu").manual_seed(self.seed) @@ -109,35 +117,11 @@ def __init__( self.val_dl_rng_seed = int(torch.randint(0, int(1e6), (1,), generator=_rng).item()) self.test_dl_rng_seed = int(torch.randint(0, int(1e6), (1,), generator=_rng).item()) - self.dataset_test: VisionDataset | None = None - - @property - def train_transforms(self) -> Callable[..., Any] | None: - """Optional transforms (or collection of transforms) you can apply to train dataset.""" - return self._train_transforms - - @train_transforms.setter - def train_transforms(self, t: Callable) -> None: - self._train_transforms = t - - @property - def val_transforms(self) -> Callable[..., Any] | None: - """Optional transforms (or collection of transforms) you can apply to validation - dataset.""" - return self._val_transforms + self.test_dataset_cls = self.dataset_cls - @val_transforms.setter - def 
val_transforms(self, t: Callable) -> None:
-        self._val_transforms = t
-
-    @property
-    def test_transforms(self) -> Callable[..., Any] | None:
-        """Optional transforms (or collection of transforms) you can apply to test dataset."""
-        return self._test_transforms
-
-    @test_transforms.setter
-    def test_transforms(self, t: Callable) -> None:
-        self._test_transforms = t
+
+        self.dataset_train: Dataset | None = None
+        self.dataset_val: Dataset | None = None
+        self.dataset_test: VisionDataset | None = None

     def prepare_data(self) -> None:
         """Saves files to data_dir."""
@@ -156,46 +140,30 @@ def prepare_data(self) -> None:
         logger.info(
             f"Preparing {self.name} dataset test split in {self.data_dir} with {test_kwargs=}"
         )
-        self.dataset_cls(str(self.data_dir), **test_kwargs)
+        self.test_dataset_cls(str(self.data_dir), **test_kwargs)

     def setup(self, stage: StageStr | None = None) -> None:
         """Creates train, val, and test dataset."""
         if stage in ["fit", "validate"] or stage is None:
-            train_transforms = (
-                self.default_transforms()
-                if self.train_transforms is None
-                else self.train_transforms
-            )
-            val_transforms = (
-                self.default_transforms() if self.val_transforms is None else self.val_transforms
-            )
-
+            logger.debug(f"creating training dataset with kwargs {self.train_kwargs}")
             dataset_train = self.dataset_cls(
                 str(self.data_dir),
-                transform=train_transforms,
                 **self.train_kwargs,
             )
-            # dataset_train = wrap_dataset_for_transforms_v2(dataset_train)
+            logger.debug(f"creating validation dataset with kwargs {self.valid_kwargs}")
             dataset_val = self.dataset_cls(
                 str(self.data_dir),
-                transform=val_transforms,
-                **self.train_kwargs,  # todo: Assuming those are the same for now.
+                **self.valid_kwargs,
             )
-            # dataset_val = wrap_dataset_for_transforms_v2(dataset_val)
-
-            # Split
+            # Train/validation split.
+            # NOTE: the dataset is created twice (with the right transforms) and split in the same
+            # way, such that there is no overlap in indices between train and validation sets.
+            
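# (both `_split_dataset` calls below shuffle indices with the same seeded generator, so the
+            # train and validation subsets are complementary)
+            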
self.dataset_train = self._split_dataset(dataset_train, train=True) self.dataset_val = self._split_dataset(dataset_val, train=False) if stage == "test" or stage is None: - test_transforms = ( - self.default_transforms() if self.test_transforms is None else self.test_transforms - ) - dataset_test = self.dataset_cls( - str(self.data_dir), transform=test_transforms, **self.test_kwargs - ) - # dataset_test = wrap_dataset_for_transforms_v2(dataset_test) - self.dataset_test = dataset_test + logger.debug(f"creating test dataset with kwargs {self.train_kwargs}") + self.dataset_test = self.test_dataset_cls(str(self.data_dir), **self.test_kwargs) def _split_dataset(self, dataset: VisionDataset, train: bool = True) -> Dataset: """Splits the dataset into train and validation set.""" diff --git a/project/datamodules/vision/imagenet.py b/project/datamodules/vision/imagenet.py deleted file mode 100644 index 2ea7f423..00000000 --- a/project/datamodules/vision/imagenet.py +++ /dev/null @@ -1,173 +0,0 @@ -from __future__ import annotations - -import contextlib -import os -import shutil -import tarfile -import time -from collections.abc import Callable -from logging import getLogger as get_logger -from pathlib import Path -from typing import ClassVar, Concatenate, Literal, TypeVar - -import torch -import tqdm -from torchvision.datasets import ImageNet - -from project.configs.datamodule import DATA_DIR -from project.datamodules.vision.base import VisionDataModule -from project.utils.types import C, H, StageStr, W -from project.utils.types.protocols import Module - -logger = get_logger(__name__) -ImageNetType = TypeVar("ImageNetType", bound=ImageNet) - - -@contextlib.contextmanager -def change_directory(path: Path): - curdir = Path.cwd() - os.chdir(path) - yield - os.chdir(curdir) - - -class ImageNetDataModule(VisionDataModule): - name: ClassVar[str] = "imagenet" - """Dataset name.""" - - dataset_cls: ClassVar[type[ImageNet]] = ImageNet - """Dataset class to use.""" - - dims: tuple[C, H, W] = (C(3), H(224), W(224)) - """A tuple describing the shape of the data.""" - - num_classes: ClassVar[int] = 1000 - - def __init__( - self, - root: str | Path = DATA_DIR, - *args, - **kwargs, - ): - super().__init__(*args, **kwargs) - self.root = Path(root) - - def prepare_data(self) -> None: - network_imagenet_dir = Path("/network/datasets/imagenet") - assert network_imagenet_dir.exists() - prepare_imagenet( - self.root, network_imagenet_dir=network_imagenet_dir, split="train", **self.EXTRA_ARGS - ) - prepare_imagenet( - self.root, network_imagenet_dir=network_imagenet_dir, split="val", **self.EXTRA_ARGS - ) - - def setup(self, stage: StageStr | None = None) -> None: - super().setup(stage) - - def default_transforms(self) -> Module[[torch.Tensor], torch.Tensor]: - from torchvision.models.resnet import ResNet152_Weights - - return ResNet152_Weights.IMAGENET1K_V1.transforms - - -def prepare_imagenet[**P]( - root: str | Path, - split: Literal["train", "val"] = "train", - network_imagenet_dir: Path = Path("/network/datasets/imagenet"), - _dataset: Callable[Concatenate[str, Literal["train", "val"], P], ImageNet] = ImageNet, - *args: P.args, - **kwargs: P.kwargs, -) -> ImageNet: - """Custom preparation function for ImageNet, using @obilaniu's tar magic in Python form. 
- - The core of this is equivalent to these bash commands: - - ```bash - mkdir -p $SLURM_TMPDIR/imagenet/val - cd $SLURM_TMPDIR/imagenet/val - tar -xf /network/scratch/b/bilaniuo/ILSVRC2012_img_val.tar - mkdir -p $SLURM_TMPDIR/imagenet/train - cd $SLURM_TMPDIR/imagenet/train - tar -xf /network/datasets/imagenet/ILSVRC2012_img_train.tar \ - --to-command='mkdir ${TAR_REALNAME%.tar}; tar -xC ${TAR_REALNAME%.tar}' - ``` - """ - root = Path(root) - if not network_imagenet_dir.exists(): - raise NotImplementedError( - f"Assuming that we're running on the Mila cluster where {network_imagenet_dir} exists for now." - ) - val_archive_file_name = "ILSVRC2012_img_val.tar" - train_archive_file_name = "ILSVRC2012_img_train.tar" - devkit_file_name = "ILSVRC2012_devkit_t12.tar.gz" - md5sums_file_name = "md5sums" - - def _symlink_if_needed(filename: str, network_imagenet_dir: Path): - symlink = root / filename - if not symlink.exists(): - symlink.symlink_to(network_imagenet_dir / filename) - - # Create a symlink to the archive in $SLURM_TMPDIR, because torchvision expects it to be - # there. - _symlink_if_needed(train_archive_file_name, network_imagenet_dir) - _symlink_if_needed(val_archive_file_name, network_imagenet_dir) - _symlink_if_needed(devkit_file_name, network_imagenet_dir) - # TODO: COPY the file, not symlink it! (otherwise we get some "Read-only filesystem" errors - # when calling tvd.ImageNet(...). (Probably because the constructor tries to open the file) - # _symlink_if_needed(md5sums_file_name, network_imagenet_dir) - md5sums_file = root / md5sums_file_name - if not md5sums_file.exists(): - shutil.copyfile(network_imagenet_dir / md5sums_file_name, md5sums_file) - md5sums_file.chmod(0o755) - - logger.info("Extracting the ImageNet archives using Olexa's tar magic in python form...") - - if split == "train": - train_dir = root / "train" - train_dir.mkdir(exist_ok=True, parents=True) - - # The ImageNet train archive is a tarfile of tarfiles (one for each class). - with tarfile.open(network_imagenet_dir / train_archive_file_name) as train_tarfile: - for member in tqdm.tqdm( - train_tarfile, - total=1000, # hard-coded here, since we know there are 1000 folders. - desc="Extracting train archive", - unit="Directories", - position=0, - ): - buffer = train_tarfile.extractfile(member) - assert buffer is not None - subdir = train_dir / member.name.replace(".tar", "") - subdir.mkdir(mode=0o755, parents=True, exist_ok=True) - files_in_subdir = set(p.name for p in subdir.iterdir()) - with tarfile.open(fileobj=buffer, mode="r|*") as sub_tarfile: - for tarinfo in sub_tarfile: - if tarinfo.name in files_in_subdir: - # Image file is already in the directory. 
- continue - sub_tarfile.extract(tarinfo, subdir) - - else: - val_dir = root / "val" - val_dir.mkdir(exist_ok=True, parents=True) - with tarfile.open(network_imagenet_dir / val_archive_file_name) as val_tarfile: - val_tarfile.extractall(val_dir) - - return _dataset(str(root), split, *args, **kwargs) - - -def main(): - slurm_tmpdir = Path(os.environ["SLURM_TMPDIR"]) - datamodule = ImageNetDataModule(slurm_tmpdir) - start = time.time() - datamodule.prepare_data() - datamodule.setup("fit") - dl = datamodule.train_dataloader() - _batch = next(iter(dl)) - end = time.time() - print(f"Prepared imagenet in {end-start:.2f}s.") - - -if __name__ == "__main__": - main() From bb41a958830b5e9e32a7fc8842c0cbf42420a79a Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Mon, 17 Jun 2024 19:43:29 +0000 Subject: [PATCH 05/27] Use .yaml files with structured base 4 datamodules Signed-off-by: Fabrice Normandin --- project/configs/__init__.py | 4 +- project/configs/datamodule/__init__.py | 135 ++---------------- project/configs/datamodule/cifar10.yaml | 6 + project/configs/datamodule/fashion_mnist.yaml | 3 + project/configs/datamodule/imagenet.yaml | 4 + project/configs/datamodule/imagenet32.yaml | 9 ++ project/configs/datamodule/inaturalist.yaml | 6 + project/configs/datamodule/mnist.yaml | 7 + project/datamodules/datamodules_test.py | 1 + .../datamodules/image_classification/base.py | 4 +- .../image_classification/fashion_mnist.py | 70 +-------- .../image_classification/inaturalist.py | 22 +-- project/datamodules/vision/base.py | 18 +-- project/utils/env_vars.py | 41 ++++++ project/utils/testutils.py | 3 +- 15 files changed, 104 insertions(+), 229 deletions(-) create mode 100644 project/configs/datamodule/cifar10.yaml create mode 100644 project/configs/datamodule/fashion_mnist.yaml create mode 100644 project/configs/datamodule/imagenet.yaml create mode 100644 project/configs/datamodule/imagenet32.yaml create mode 100644 project/configs/datamodule/inaturalist.yaml create mode 100644 project/configs/datamodule/mnist.yaml create mode 100644 project/utils/env_vars.py diff --git a/project/configs/__init__.py b/project/configs/__init__.py index bb734082..13e6a8d7 100644 --- a/project/configs/__init__.py +++ b/project/configs/__init__.py @@ -2,11 +2,9 @@ from hydra.core.config_store import ConfigStore +from ..utils.env_vars import REPO_ROOTDIR, SLURM_JOB_ID, SLURM_TMPDIR from .config import Config from .datamodule import ( - REPO_ROOTDIR, - SLURM_JOB_ID, - SLURM_TMPDIR, datamodule_store, ) from .network import network_store diff --git a/project/configs/datamodule/__init__.py b/project/configs/datamodule/__init__.py index 541a35e1..ac98a4fb 100644 --- a/project/configs/datamodule/__init__.py +++ b/project/configs/datamodule/__init__.py @@ -1,86 +1,31 @@ -import os -from collections.abc import Callable -from dataclasses import dataclass, field from logging import getLogger as get_logger from pathlib import Path -import torch from hydra_zen import hydrated_dataclass, instantiate, store -from torch import Tensor from project.datamodules import ( - CIFAR10DataModule, - FashionMNISTDataModule, - ImageNet32DataModule, - MNISTDataModule, VisionDataModule, ) -from project.datamodules.image_classification.cifar10 import cifar10_train_transforms -from project.datamodules.image_classification.imagenet32 import imagenet32_train_transforms -from project.datamodules.image_classification.inaturalist import ( - INaturalistDataModule, - TargetType, - Version, -) -from project.datamodules.image_classification.mnist import 
mnist_train_transforms -from project.datamodules.vision.base import SLURM_TMPDIR - -FILE = Path(__file__) -REPO_ROOTDIR = FILE.parent -for level in range(5): - if "README.md" in list(p.name for p in REPO_ROOTDIR.iterdir()): - break - REPO_ROOTDIR = REPO_ROOTDIR.parent - -SLURM_JOB_ID: int | None = ( - int(os.environ["SLURM_JOB_ID"]) if "SLURM_JOB_ID" in os.environ else None -) +from project.utils.env_vars import DATA_DIR, NETWORK_DIR, NUM_WORKERS logger = get_logger(__name__) - -TORCHVISION_DIR: Path | None = None - -_torchvision_dir = Path("/network/datasets/torchvision") -if _torchvision_dir.exists() and _torchvision_dir.is_dir(): - TORCHVISION_DIR = _torchvision_dir - - -if not SLURM_TMPDIR and SLURM_JOB_ID is not None: - # This can happens when running the integrated VSCode terminal with `mila code`! - _slurm_tmpdir = Path(f"/Tmp/slurm.{SLURM_JOB_ID}.0") - if _slurm_tmpdir.exists(): - SLURM_TMPDIR = _slurm_tmpdir -SCRATCH = Path(os.environ["SCRATCH"]) if "SCRATCH" in os.environ else None -DATA_DIR = Path(os.environ.get("DATA_DIR", (SLURM_TMPDIR or SCRATCH or REPO_ROOTDIR) / "data")) - -NUM_WORKERS = int( - os.environ.get( - "SLURM_CPUS_PER_TASK", - os.environ.get( - "SLURM_CPUS_ON_NODE", - len(os.sched_getaffinity(0)) - if hasattr(os, "sched_getaffinity") - else torch.multiprocessing.cpu_count(), - ), - ) -) -logger = get_logger(__name__) - - -Transform = Callable[[Tensor], Tensor] - - -@dataclass -class DataModuleConfig: ... +torchvision_dir: Path | None = None +"""Network directory with torchvision datasets.""" +if ( + NETWORK_DIR + and (_torchvision_dir := NETWORK_DIR / "datasets/torchvision").exists() + and _torchvision_dir.is_dir() +): + torchvision_dir = _torchvision_dir datamodule_store = store(group="datamodule") @hydrated_dataclass(target=VisionDataModule, populate_full_signature=True) -class VisionDataModuleConfig(DataModuleConfig): - data_dir: str | None = str(TORCHVISION_DIR or DATA_DIR) +class VisionDataModuleConfig: + data_dir: str | None = str(torchvision_dir or DATA_DIR) val_split: int | float = 0.1 # NOTE: reduced from default of 0.2 num_workers: int = NUM_WORKERS normalize: bool = True # NOTE: Set to True by default instead of False @@ -93,60 +38,4 @@ class VisionDataModuleConfig(DataModuleConfig): __call__ = instantiate -# todo: look into this to avoid having to make dataclasses with no fields just to call a function.. -from hydra_zen import store, zen # noqa - - -# FIXME: This is dumb! -@hydrated_dataclass(target=mnist_train_transforms) -class MNISTTrainTransforms: ... - - -@hydrated_dataclass(target=MNISTDataModule, populate_full_signature=True) -class MNISTDataModuleConfig(VisionDataModuleConfig): - normalize: bool = True - batch_size: int = 128 - train_transforms: MNISTTrainTransforms = field(default_factory=MNISTTrainTransforms) - - -@hydrated_dataclass(target=FashionMNISTDataModule, populate_full_signature=True) -class FashionMNISTDataModuleConfig(MNISTDataModuleConfig): ... - - -@hydrated_dataclass(target=cifar10_train_transforms) -class Cifar10TrainTransforms: ... - - -@hydrated_dataclass(target=CIFAR10DataModule, populate_full_signature=True) -class CIFAR10DataModuleConfig(VisionDataModuleConfig): - train_transforms: Cifar10TrainTransforms = field(default_factory=Cifar10TrainTransforms) - # Overwriting this one: - batch_size: int = 128 - - -@hydrated_dataclass(target=imagenet32_train_transforms) -class ImageNet32TrainTransforms: ... 
- - -@hydrated_dataclass(target=ImageNet32DataModule, populate_full_signature=True) -class ImageNet32DataModuleConfig(VisionDataModuleConfig): - data_dir: Path = ((SCRATCH / "data") if SCRATCH else DATA_DIR) / "imagenet32" - - val_split: int | float = -1 - num_images_per_val_class: int = 50 # Slightly different. - normalize: bool = True - train_transforms: ImageNet32TrainTransforms = field(default_factory=ImageNet32TrainTransforms) - - -@hydrated_dataclass(target=INaturalistDataModule, populate_full_signature=True) -class INaturalistDataModuleConfig(VisionDataModuleConfig): - data_dir: Path | None = None - version: Version = "2021_train" - target_type: TargetType | list[TargetType] = "full" - - -datamodule_store(CIFAR10DataModuleConfig, name="cifar10") -datamodule_store(MNISTDataModuleConfig, name="mnist") -datamodule_store(FashionMNISTDataModuleConfig, name="fashion_mnist") -datamodule_store(ImageNet32DataModuleConfig, name="imagenet32") -datamodule_store(INaturalistDataModuleConfig, name="inaturalist") +datamodule_store(VisionDataModuleConfig, name="vision") diff --git a/project/configs/datamodule/cifar10.yaml b/project/configs/datamodule/cifar10.yaml new file mode 100644 index 00000000..e8d3fb78 --- /dev/null +++ b/project/configs/datamodule/cifar10.yaml @@ -0,0 +1,6 @@ +defaults: +- vision +_target_: project.datamodules.CIFAR10DataModule +batch_size: 128 +train_transforms: + _target_: project.datamodules.image_classification.cifar10.cifar10_train_transforms diff --git a/project/configs/datamodule/fashion_mnist.yaml b/project/configs/datamodule/fashion_mnist.yaml new file mode 100644 index 00000000..f2038d2d --- /dev/null +++ b/project/configs/datamodule/fashion_mnist.yaml @@ -0,0 +1,3 @@ +defaults: +- mnist +_target_: project.datamodules.FashionMNISTDataModule diff --git a/project/configs/datamodule/imagenet.yaml b/project/configs/datamodule/imagenet.yaml new file mode 100644 index 00000000..78ae4c6d --- /dev/null +++ b/project/configs/datamodule/imagenet.yaml @@ -0,0 +1,4 @@ +defaults: +- vision +_target_: project.datamodules.ImageNetDataModule +# todo: add good configuration options here. diff --git a/project/configs/datamodule/imagenet32.yaml b/project/configs/datamodule/imagenet32.yaml new file mode 100644 index 00000000..876a037c --- /dev/null +++ b/project/configs/datamodule/imagenet32.yaml @@ -0,0 +1,9 @@ +defaults: +- vision +_target_: project.datamodules.ImageNet32DataModule +data_dir: "${oc.env:SCRATCH}/data" +val_split: -1 +num_images_per_val_class: 50 # Slightly different. 
+normalize: True +train_transforms: + _target_: project.datamodules.image_classification.imagenet32.imagenet32_train_transforms diff --git a/project/configs/datamodule/inaturalist.yaml b/project/configs/datamodule/inaturalist.yaml new file mode 100644 index 00000000..5cea87f5 --- /dev/null +++ b/project/configs/datamodule/inaturalist.yaml @@ -0,0 +1,6 @@ +defaults: +- vision +_target_: project.datamodules.INaturalistDataModule +data_dir: null +version: "2021_train" +target_type: "full" diff --git a/project/configs/datamodule/mnist.yaml b/project/configs/datamodule/mnist.yaml new file mode 100644 index 00000000..c9a16639 --- /dev/null +++ b/project/configs/datamodule/mnist.yaml @@ -0,0 +1,7 @@ +defaults: +- vision +_target_: project.datamodules.MNISTDataModule +normalize: True +batch_size: 128 +train_transforms: + _target_: project.datamodules.image_classification.mnist.mnist_train_transforms diff --git a/project/datamodules/datamodules_test.py b/project/datamodules/datamodules_test.py index 432a4f7f..5667990e 100644 --- a/project/datamodules/datamodules_test.py +++ b/project/datamodules/datamodules_test.py @@ -21,6 +21,7 @@ def test_first_batch( original_datadir: Path, datadir: Path, ): + # todo: skip this test if the dataset isn't already downloaded (for example on the GitHub CI). datamodule.prepare_data() datamodule.setup("fit") diff --git a/project/datamodules/image_classification/base.py b/project/datamodules/image_classification/base.py index 0741e021..8b392437 100644 --- a/project/datamodules/image_classification/base.py +++ b/project/datamodules/image_classification/base.py @@ -5,9 +5,11 @@ from project.datamodules.vision.base import VisionDataModule from project.utils.types import C, H, W +# todo: decide if this should be a protocol or an actual base class (currently a base class). + class ImageClassificationDataModule[BatchType: tuple[Tensor, Tensor]](VisionDataModule[BatchType]): - """Protocol that describes lightning data modules for image classification.""" + """Lightning data modules for image classification.""" num_classes: int """Number of classes in the dataset.""" diff --git a/project/datamodules/image_classification/fashion_mnist.py b/project/datamodules/image_classification/fashion_mnist.py index df42a784..8b8c080d 100644 --- a/project/datamodules/image_classification/fashion_mnist.py +++ b/project/datamodules/image_classification/fashion_mnist.py @@ -1,17 +1,11 @@ from __future__ import annotations -from collections.abc import Callable -from typing import Any - -import torch from torchvision.datasets import FashionMNIST -from torchvision.transforms import v2 as transform_lib -from project.datamodules.image_classification.base import ImageClassificationDataModule -from project.utils.types import C, H, W +from project.datamodules.image_classification.mnist import MNISTDataModule -class FashionMNISTDataModule(ImageClassificationDataModule): +class FashionMNISTDataModule(MNISTDataModule): """ .. 
figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/ wp-content/uploads/2019/02/Plot-of-a-Subset-of-Images-from-the-Fashion-MNIST-Dataset.png @@ -42,63 +36,3 @@ class FashionMNISTDataModule(ImageClassificationDataModule): name = "fashion_mnist" dataset_cls = FashionMNIST - dims = (C(1), H(28), W(28)) - num_classes = 10 - - def __init__( - self, - data_dir: str | None = None, - val_split: int | float = 0.2, - num_workers: int | None = 0, - normalize: bool = False, - batch_size: int = 32, - seed: int = 42, - shuffle: bool = True, - pin_memory: bool = True, - drop_last: bool = False, - *args: Any, - **kwargs: Any, - ) -> None: - """ - - Args: - data_dir: Root directory of dataset. - val_split: Percent (float) or number (int) of samples to use for the validation split. - num_workers: Number of workers to use for loading data. - normalize: If ``True``, applies image normalization. - batch_size: Number of samples per batch to load. - seed: Random seed to be used for train/val/test splits. - shuffle: If ``True``, shuffles the train data every epoch. - pin_memory: If ``True``, the data loader will copy Tensors into CUDA pinned memory \ - before returning them. - drop_last: If ``True``, drops the last incomplete batch. - """ - super().__init__( - data_dir=data_dir, - val_split=val_split, - num_workers=num_workers, - normalize=normalize, - batch_size=batch_size, - seed=seed, - shuffle=shuffle, - pin_memory=pin_memory, - drop_last=drop_last, - *args, - **kwargs, - ) - self.prepare_data() - self.setup("fit") - - def default_transforms(self) -> Callable: - if self.normalize: - mnist_transforms = transform_lib.Compose( - [ - transform_lib.ToImage(), - transform_lib.ToDtype(torch.float32, scale=True), - transform_lib.Normalize(mean=(0.5,), std=(0.5,)), - ] - ) - else: - mnist_transforms = transform_lib.Compose([transform_lib.ToImage()]) - - return mnist_transforms diff --git a/project/datamodules/image_classification/inaturalist.py b/project/datamodules/image_classification/inaturalist.py index fffff33c..467f1eaa 100644 --- a/project/datamodules/image_classification/inaturalist.py +++ b/project/datamodules/image_classification/inaturalist.py @@ -1,6 +1,5 @@ from __future__ import annotations -import os import warnings from collections.abc import Callable from logging import getLogger as get_logger @@ -11,6 +10,7 @@ from torchvision.datasets import INaturalist from project.datamodules.image_classification.base import ImageClassificationDataModule +from project.datamodules.vision.base import SLURM_TMPDIR from project.utils.types import C, H, W logger = get_logger(__name__) @@ -25,23 +25,6 @@ Version = Version2017_2019 | Version2021 -def get_slurm_tmpdir() -> Path: - if "SLURM_TMPDIR" in os.environ: - return Path(os.environ["SLURM_TMPDIR"]) - if "SLURM_JOB_ID" not in os.environ: - raise RuntimeError( - "SLURM_JOBID environment variable isn't set. Are you running this from a SLURM " - "cluster?" - ) - slurm_tmpdir = Path(f"/Tmp/slurm.{os.environ['SLURM_JOB_ID']}.0") - if not slurm_tmpdir.is_dir(): - raise NotImplementedError( - f"TODO: You appear to be running this outside the Mila cluster, since SLURM_TMPDIR " - f"isn't located at {slurm_tmpdir}." - ) - return slurm_tmpdir - - def inat_dataset_dir() -> Path: network_dir = Path("/network/datasets/inat") if not network_dir.exists(): @@ -79,7 +62,8 @@ def __init__( ) -> None: # assuming that we're on the Mila cluster atm. 
self.network_dir = inat_dataset_dir() - slurm_tmpdir = get_slurm_tmpdir() + assert SLURM_TMPDIR, "assuming that we're on a compute node." + slurm_tmpdir = SLURM_TMPDIR default_data_dir = slurm_tmpdir / "data" if data_dir is None: data_dir = default_data_dir diff --git a/project/datamodules/vision/base.py b/project/datamodules/vision/base.py index 9ec9bc22..7dbf5f24 100644 --- a/project/datamodules/vision/base.py +++ b/project/datamodules/vision/base.py @@ -12,21 +12,11 @@ from lightning import LightningDataModule from torch.utils.data import DataLoader, Dataset, random_split from torchvision.datasets import VisionDataset -from typing_extensions import ParamSpec from project.utils.types import C, H, StageStr, W from ...utils.types.protocols import DataModule -P = ParamSpec("P") - -SLURM_TMPDIR: Path | None = ( - Path(os.environ["SLURM_TMPDIR"]) - if "SLURM_TMPDIR" in os.environ - else tmp - if "SLURM_JOB_ID" in os.environ and (tmp := Path("/tmp")).exists() - else None -) logger = get_logger(__name__) @@ -195,7 +185,7 @@ def _get_splits(self, len_dataset: int) -> list[int]: def default_transforms(self) -> Callable: """Default transform for the dataset.""" - def train_dataloader( + def train_dataloader[**P]( self, _dataloader_fn: Callable[Concatenate[Dataset, P], DataLoader] = DataLoader, *args: P.args, @@ -215,7 +205,7 @@ def train_dataloader( ), ) - def val_dataloader( + def val_dataloader[**P]( self, _dataloader_fn: Callable[Concatenate[Dataset, P], DataLoader] = DataLoader, *args: P.args, @@ -230,7 +220,7 @@ def val_dataloader( **(dict(generator=torch.Generator().manual_seed(self.val_dl_rng_seed)) | kwargs), ) - def test_dataloader( + def test_dataloader[**P]( self, _dataloader_fn: Callable[Concatenate[Dataset, P], DataLoader] = DataLoader, *args: P.args, @@ -247,7 +237,7 @@ def test_dataloader( **(dict(generator=torch.Generator().manual_seed(self.test_dl_rng_seed)) | kwargs), ) - def _data_loader( + def _data_loader[**P]( self, dataset: Dataset, _dataloader_fn: Callable[Concatenate[Dataset, P], DataLoader] = DataLoader, diff --git a/project/utils/env_vars.py b/project/utils/env_vars.py new file mode 100644 index 00000000..6c4b5c68 --- /dev/null +++ b/project/utils/env_vars.py @@ -0,0 +1,41 @@ +import os +from pathlib import Path + +import torch + +SLURM_TMPDIR: Path | None = ( + Path(os.environ["SLURM_TMPDIR"]) + if "SLURM_TMPDIR" in os.environ + else tmp + if "SLURM_JOB_ID" in os.environ and (tmp := Path("/tmp")).exists() + else None +) +SLURM_JOB_ID: int | None = ( + int(os.environ["SLURM_JOB_ID"]) if "SLURM_JOB_ID" in os.environ else None +) +NETWORK_DIR = Path(os.environ["NETWORK_DIR"]) if "NETWORK_DIR" in os.environ else None + +REPO_ROOTDIR = Path(__file__).parent +for level in range(5): + if "README.md" in list(p.name for p in REPO_ROOTDIR.iterdir()): + break + REPO_ROOTDIR = REPO_ROOTDIR.parent + +SCRATCH = Path(os.environ["SCRATCH"]) if "SCRATCH" in os.environ else None +"""SCRATCH directory where logs / checkpoints / custom datasets should be saved.""" + +DATA_DIR = Path(os.environ.get("DATA_DIR", (SLURM_TMPDIR or SCRATCH or REPO_ROOTDIR) / "data")) +"""Directory where datasets should be extracted.""" + + +NUM_WORKERS = int( + os.environ.get( + "SLURM_CPUS_PER_TASK", + os.environ.get( + "SLURM_CPUS_ON_NODE", + len(os.sched_getaffinity(0)) + if hasattr(os, "sched_getaffinity") + else torch.multiprocessing.cpu_count(), + ), + ) +) diff --git a/project/utils/testutils.py b/project/utils/testutils.py index bfe9a203..2c2de328 100644 --- a/project/utils/testutils.py +++ 
b/project/utils/testutils.py @@ -24,12 +24,13 @@ from torch.optim import Optimizer from project.configs import Config, cs -from project.configs.datamodule import DATA_DIR, SLURM_JOB_ID +from project.configs.datamodule import DATA_DIR from project.datamodules.image_classification import ( ImageClassificationDataModule, ) from project.datamodules.vision.base import VisionDataModule from project.experiment import instantiate_trainer +from project.utils.env_vars import SLURM_JOB_ID from project.utils.hydra_utils import get_attr, get_outer_class from project.utils.types import PhaseStr from project.utils.types.protocols import DataModule From 3bd2ed481112d3250e3a99ddbb5dca7b04b5e7fc Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Mon, 17 Jun 2024 19:54:00 +0000 Subject: [PATCH 06/27] Fix typo in log, change preparation benchmark dir Signed-off-by: Fabrice Normandin --- project/datamodules/image_classification/imagenet.py | 9 +++++---- project/datamodules/vision/base.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/project/datamodules/image_classification/imagenet.py b/project/datamodules/image_classification/imagenet.py index c4297495..ce1f76b9 100644 --- a/project/datamodules/image_classification/imagenet.py +++ b/project/datamodules/image_classification/imagenet.py @@ -201,7 +201,8 @@ def train_transform(self) -> Module[[torch.Tensor], torch.Tensor]: [ transform_lib.RandomResizedCrop(self.image_size), transform_lib.RandomHorizontalFlip(), - transform_lib.ToTensor(), + transform_lib.ToImage(), + transform_lib.ToDtype(torch.float32, scale=True), imagenet_normalization(), ] ) @@ -226,7 +227,8 @@ def val_transform(self) -> Callable: [ transform_lib.Resize(self.image_size + 32), transform_lib.CenterCrop(self.image_size), - transform_lib.ToTensor(), + transform_lib.ToImage(), + transform_lib.ToDtype(torch.float32, scale=True), imagenet_normalization(), ] ) @@ -407,8 +409,7 @@ def main(): logging.basicConfig( level=logging.DEBUG, format="%(message)s", handlers=[rich.logging.RichHandler()] ) - slurm_tmpdir = Path(os.environ["SLURM_TMPDIR"]) - datamodule = ImageNetDataModule(slurm_tmpdir) + datamodule = ImageNetDataModule() start = time.time() datamodule.prepare_data() datamodule.setup("fit") diff --git a/project/datamodules/vision/base.py b/project/datamodules/vision/base.py index 7dbf5f24..dc083d46 100644 --- a/project/datamodules/vision/base.py +++ b/project/datamodules/vision/base.py @@ -140,7 +140,7 @@ def setup(self, stage: StageStr | None = None) -> None: str(self.data_dir), **self.train_kwargs, ) - logger.debug(f"creating validation dataset with kwargs {self.train_kwargs}") + logger.debug(f"creating validation dataset with kwargs {self.valid_kwargs}") dataset_val = self.dataset_cls( str(self.data_dir), **self.valid_kwargs, From 0bbdcab6a8aa6564ccb03a71775f222ac9c09aba Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Mon, 17 Jun 2024 20:03:15 +0000 Subject: [PATCH 07/27] Reduce logging verbosity a bit Signed-off-by: Fabrice Normandin --- project/datamodules/vision/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/project/datamodules/vision/base.py b/project/datamodules/vision/base.py index dc083d46..ae251fef 100644 --- a/project/datamodules/vision/base.py +++ b/project/datamodules/vision/base.py @@ -122,12 +122,12 @@ def prepare_data(self) -> None: if _has_constructor_argument(self.dataset_cls, "download"): train_kwargs["download"] = True test_kwargs["download"] = True - logger.info( + logger.debug( f"Preparing {self.name} dataset 
training split in {self.data_dir} with {train_kwargs}"
         )
         self.dataset_cls(str(self.data_dir), **train_kwargs)
         if test_kwargs != train_kwargs:
-            logger.info(
+            logger.debug(
                 f"Preparing {self.name} dataset test split in {self.data_dir} with {test_kwargs=}"
             )
             self.test_dataset_cls(str(self.data_dir), **test_kwargs)

From 55592c8741dfc47cb7c3474acf0f338417f242b6 Mon Sep 17 00:00:00 2001
From: Fabrice Normandin
Date: Mon, 17 Jun 2024 20:03:31 +0000
Subject: [PATCH 08/27] Temporarily add a project cmd -> project/main.py

Signed-off-by: Fabrice Normandin
---
 pyproject.toml | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/pyproject.toml b/pyproject.toml
index 1443b0ff..15b6ba14 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,6 +32,9 @@ requires-python = ">=3.12"
 readme = "README.md"
 license = {text = "MIT"}
 
+[project.scripts]
+project = "project:main.main"
+
 [tool.setuptools]
 packages = ["project"]

From d0802a59a26c0de82ed36be352bf86d16566e2bd Mon Sep 17 00:00:00 2001
From: Fabrice Normandin
Date: Mon, 17 Jun 2024 20:08:33 +0000
Subject: [PATCH 09/27] Fix tiny bugs and import errors

Signed-off-by: Fabrice Normandin
---
 .../image_classification/imagenet32_test.py   |  2 +-
 .../image_classification/inaturalist.py       |  2 +-
 .../datamodules/image_classification/mnist.py |  6 +---
 .../image_classification/transforms.py        | 31 -------------------
 project/main_test.py                          |  6 ++--
 5 files changed, 7 insertions(+), 40 deletions(-)
 delete mode 100644 project/datamodules/image_classification/transforms.py

diff --git a/project/datamodules/image_classification/imagenet32_test.py b/project/datamodules/image_classification/imagenet32_test.py
index 771978cc..0478fbfe 100644
--- a/project/datamodules/image_classification/imagenet32_test.py
+++ b/project/datamodules/image_classification/imagenet32_test.py
@@ -3,7 +3,7 @@
 
 import pytest
 
-from project.configs.datamodule import SCRATCH
+from project.utils.env_vars import SCRATCH
 
 from .imagenet32 import ImageNet32DataModule
 
diff --git a/project/datamodules/image_classification/inaturalist.py b/project/datamodules/image_classification/inaturalist.py
index 467f1eaa..9bf192cd 100644
--- a/project/datamodules/image_classification/inaturalist.py
+++ b/project/datamodules/image_classification/inaturalist.py
@@ -10,7 +10,7 @@
 from torchvision.datasets import INaturalist
 
 from project.datamodules.image_classification.base import ImageClassificationDataModule
-from project.datamodules.vision.base import SLURM_TMPDIR
+from project.utils.env_vars import SLURM_TMPDIR
 from project.utils.types import C, H, W
 
 logger = get_logger(__name__)
diff --git a/project/datamodules/image_classification/mnist.py b/project/datamodules/image_classification/mnist.py
index dc268aa2..905aa261 100644
--- a/project/datamodules/image_classification/mnist.py
+++ b/project/datamodules/image_classification/mnist.py
@@ -27,11 +27,7 @@ def mnist_train_transforms():
 def mnist_normalization():
     # NOTE: Taken from https://stackoverflow.com/a/67233938/6388696
     # return transforms.Normalize(mean=0.5, std=0.5)
-    return transforms.Compose(
-        [
-            transforms.Normalize(mean=[0.1307], std=[0.3081]),
-        ]
-    )
+    return transforms.Normalize(mean=[0.1307], std=[0.3081])
 
 
 def mnist_unnormalization(x: Tensor) -> Tensor:
diff --git a/project/datamodules/image_classification/transforms.py b/project/datamodules/image_classification/transforms.py
deleted file mode 100644
index 283541d4..00000000
--- a/project/datamodules/image_classification/transforms.py
+++ /dev/null
@@ -1,31 +0,0 @@
-from collections.abc import Callable
-
-from
torchvision import transforms
-
-
-def imagenet_normalization() -> Callable:
-    return transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
-
-
-def cifar10_normalization() -> Callable:
-    return transforms.Normalize(
-        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
-        std=[x / 255.0 for x in [63.0, 62.1, 66.7]],
-    )
-
-
-def stl10_normalization() -> Callable:
-    return transforms.Normalize(mean=(0.43, 0.42, 0.39), std=(0.27, 0.26, 0.27))
-
-
-def emnist_normalization(split: str) -> Callable:
-    # `stats` contains mean and std for each `split`.
-    stats = {
-        "balanced": (0.175, 0.333),
-        "byclass": (0.174, 0.332),
-        "bymerge": (0.174, 0.332),
-        "digits": (0.173, 0.332),
-        "letters": (0.172, 0.331),
-        "mnist": (0.173, 0.332),
-    }
-    return transforms.Normalize(mean=stats[split][0], std=stats[split][1])
diff --git a/project/main_test.py b/project/main_test.py
index a855ba57..32d16e4b 100644
--- a/project/main_test.py
+++ b/project/main_test.py
@@ -9,7 +9,6 @@
 
 from project.algorithms import Algorithm, ExampleAlgorithm
 from project.configs.config import Config
-from project.configs.datamodule import CIFAR10DataModuleConfig
 from project.conftest import setup_hydra_for_tests_and_compose, use_overrides
 from project.datamodules.image_classification.cifar10 import CIFAR10DataModule
 from project.networks.fcnet import FcNet
@@ -42,7 +41,10 @@ def set_testing_hydra_dir():
 @use_overrides([""])
 def test_defaults(experiment_config: Config) -> None:
     assert isinstance(experiment_config.algorithm, ExampleAlgorithm.HParams)
-    assert isinstance(experiment_config.datamodule, CIFAR10DataModuleConfig | CIFAR10DataModule)
+    assert (
+        isinstance(experiment_config.datamodule, CIFAR10DataModule)
+        or hydra_zen.get_target(experiment_config.datamodule) is CIFAR10DataModule
+    )
 
 
 def _ids(v):

From c9fe73af3322fa64368e0cf1afbde4eb22900a10 Mon Sep 17 00:00:00 2001
From: Fabrice Normandin
Date: Mon, 17 Jun 2024 20:48:58 +0000
Subject: [PATCH 10/27] Add notes about extending structured configs

Signed-off-by: Fabrice Normandin
---
 project/configs/datamodule/__init__.py        |  5 ++++-
 project/configs/datamodule/imagenet32.yaml    |  4 +++-
 .../trainer/callbacks/no_checkpoints.yaml     | 15 ++++++++++-----
 project/datamodules/vision/base.py            | 11 ++++-------
 project/experiment.py                         |  6 +++---
 project/utils/env_vars.py                     |  9 ++++++++-
 project/utils/testutils.py                    | 17 +++++++++++++++--
 7 files changed, 47 insertions(+), 20 deletions(-)

diff --git a/project/configs/datamodule/__init__.py b/project/configs/datamodule/__init__.py
index ac98a4fb..61bd5077 100644
--- a/project/configs/datamodule/__init__.py
+++ b/project/configs/datamodule/__init__.py
@@ -20,7 +20,9 @@
     torchvision_dir = _torchvision_dir
 
 
-datamodule_store = store(group="datamodule")
+# TODO: Make it possible to extend a structured base via yaml files as well as adding new fields
+# (for example, ImageNet32DataModule has a new constructor argument which can't be set atm in the
+# config).
@hydrated_dataclass(target=VisionDataModule, populate_full_signature=True) @@ -38,4 +40,5 @@ class VisionDataModuleConfig: __call__ = instantiate +datamodule_store = store(group="datamodule") datamodule_store(VisionDataModuleConfig, name="vision") diff --git a/project/configs/datamodule/imagenet32.yaml b/project/configs/datamodule/imagenet32.yaml index 876a037c..9c4bb7d7 100644 --- a/project/configs/datamodule/imagenet32.yaml +++ b/project/configs/datamodule/imagenet32.yaml @@ -3,7 +3,9 @@ defaults: _target_: project.datamodules.ImageNet32DataModule data_dir: "${oc.env:SCRATCH}/data" val_split: -1 -num_images_per_val_class: 50 # Slightly different. +# TODO: Can't currently add this key since it isn't in the structured config for the +# `VisionDataModule`. +# num_images_per_val_class: 50 normalize: True train_transforms: _target_: project.datamodules.image_classification.imagenet32.imagenet32_train_transforms diff --git a/project/configs/trainer/callbacks/no_checkpoints.yaml b/project/configs/trainer/callbacks/no_checkpoints.yaml index 49dbccf0..ee01c442 100644 --- a/project/configs/trainer/callbacks/no_checkpoints.yaml +++ b/project/configs/trainer/callbacks/no_checkpoints.yaml @@ -1,6 +1,11 @@ -model_summary: - _target_: lightning.pytorch.callbacks.RichModelSummary - max_depth: 1 +defaults: + - default -rich_progress_bar: - _target_: lightning.pytorch.callbacks.RichProgressBar +model_checkpoint: null + +# model_summary: +# _target_: lightning.pytorch.callbacks.RichModelSummary +# max_depth: 1 + +# rich_progress_bar: +# _target_: lightning.pytorch.callbacks.RichProgressBar diff --git a/project/datamodules/vision/base.py b/project/datamodules/vision/base.py index ae251fef..b75d0a9d 100644 --- a/project/datamodules/vision/base.py +++ b/project/datamodules/vision/base.py @@ -13,6 +13,7 @@ from torch.utils.data import DataLoader, Dataset, random_split from torchvision.datasets import VisionDataset +from project.utils.env_vars import DATA_DIR, NUM_WORKERS from project.utils.types import C, H, StageStr, W from ...utils.types.protocols import DataModule @@ -37,9 +38,9 @@ class VisionDataModule[BatchType_co](LightningDataModule, DataModule[BatchType_c def __init__( self, - data_dir: str | Path | None = None, + data_dir: str | Path = DATA_DIR, val_split: int | float = 0.2, - num_workers: int | None = None, + num_workers: int | None = NUM_WORKERS, normalize: bool = False, batch_size: int = 32, seed: int = 42, @@ -69,13 +70,8 @@ def __init__( """ super().__init__() - from project.configs.datamodule import DATA_DIR - self.data_dir: Path = Path(data_dir or DATA_DIR) self.val_split = val_split - if num_workers is None: - num_workers = num_cpus_on_node() - logger.debug(f"Setting the number of dataloader workers to {num_workers}.") self.num_workers = num_workers self.normalize = normalize self.batch_size = batch_size @@ -102,6 +98,7 @@ def __init__( self.valid_kwargs["train"] = True self.test_kwargs["train"] = False + # todo: what about the shuffling at each epoch? _rng = torch.Generator(device="cpu").manual_seed(self.seed) self.train_dl_rng_seed = int(torch.randint(0, int(1e6), (1,), generator=_rng).item()) self.val_dl_rng_seed = int(torch.randint(0, int(1e6), (1,), generator=_rng).item()) diff --git a/project/experiment.py b/project/experiment.py index f4112356..d437a9b1 100644 --- a/project/experiment.py +++ b/project/experiment.py @@ -112,9 +112,9 @@ def instantiate_trainer(experiment_config: Config) -> Trainer: # fields have the right type. 
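+    # NOTE: callback entries that are set to `null` in a yaml config (e.g.
+    # `model_checkpoint: null` in `no_checkpoints.yaml` above) are filtered out before
+    # instantiation below, so a config can disable a callback inherited from its `defaults`.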
# instantiate all the callbacks - callbacks: dict[str, Callback] | None = hydra_zen.instantiate( - experiment_config.trainer.pop("callbacks", {}) - ) + callback_configs = experiment_config.trainer.pop("callbacks", {}) + callback_configs = {k: v for k, v in callback_configs.items() if v is not None} + callbacks: dict[str, Callback] | None = hydra_zen.instantiate(callback_configs) # Create the loggers, if any. loggers: dict[str, Any] | None = instantiate(experiment_config.trainer.pop("logger", {})) # Create the Trainer. diff --git a/project/utils/env_vars.py b/project/utils/env_vars.py index 6c4b5c68..78ebdb38 100644 --- a/project/utils/env_vars.py +++ b/project/utils/env_vars.py @@ -13,7 +13,14 @@ SLURM_JOB_ID: int | None = ( int(os.environ["SLURM_JOB_ID"]) if "SLURM_JOB_ID" in os.environ else None ) -NETWORK_DIR = Path(os.environ["NETWORK_DIR"]) if "NETWORK_DIR" in os.environ else None + +NETWORK_DIR = ( + Path(os.environ["NETWORK_DIR"]) + if "NETWORK_DIR" in os.environ + else _network_dir + if (_network_dir := Path("/network")).exists() + else None +) REPO_ROOTDIR = Path(__file__).parent for level in range(5): diff --git a/project/utils/testutils.py b/project/utils/testutils.py index 2c2de328..73fb02b7 100644 --- a/project/utils/testutils.py +++ b/project/utils/testutils.py @@ -6,6 +6,7 @@ import dataclasses import hashlib import importlib +import os from collections.abc import Mapping, Sequence from contextlib import contextmanager from logging import getLogger as get_logger @@ -30,13 +31,16 @@ ) from project.datamodules.vision.base import VisionDataModule from project.experiment import instantiate_trainer -from project.utils.env_vars import SLURM_JOB_ID +from project.utils.env_vars import NETWORK_DIR, SLURM_JOB_ID from project.utils.hydra_utils import get_attr, get_outer_class from project.utils.types import PhaseStr from project.utils.types.protocols import DataModule from project.utils.utils import get_device -SLOW_DATAMODULES = ["inaturalist", "imagenet32"] +on_github_ci = "GITHUB_ACTIONS" in os.environ +on_self_hosted_github_ci = on_github_ci and "self-hosted" in os.environ.get("RUNNER_LABELS", "") + +SLOW_DATAMODULES = ["inaturalist", "imagenet32", "imagenet"] default_marks_for_config_name: dict[str, list[pytest.MarkDecorator]] = { "imagenet32": [pytest.mark.slow], @@ -49,6 +53,15 @@ reason="Expects to be run on the Mila cluster for now", ), ], + "imagenet": [ + pytest.mark.slow, + pytest.mark.xfail( + not (NETWORK_DIR and (NETWORK_DIR / "datasets/imagenet").exists()), + strict=True, + raises=hydra.errors.InstantiationException, + reason="Expects to be run on a cluster with the ImageNet dataset.", + ), + ], "rl": [ pytest.mark.xfail( strict=False, From 5a9b70fa22bf4631dfc5ef0842c1f56ea10f52d6 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Tue, 18 Jun 2024 11:54:34 -0400 Subject: [PATCH 11/27] Fix some issues, fix defaults Signed-off-by: Fabrice Normandin --- project/configs/datamodule/imagenet.yaml | 2 +- .../image_classification/imagenet.py | 64 +++++++++++-------- project/datamodules/vision/base.py | 46 +++++++++---- 3 files changed, 72 insertions(+), 40 deletions(-) diff --git a/project/configs/datamodule/imagenet.yaml b/project/configs/datamodule/imagenet.yaml index 78ae4c6d..3e82c78b 100644 --- a/project/configs/datamodule/imagenet.yaml +++ b/project/configs/datamodule/imagenet.yaml @@ -1,4 +1,4 @@ defaults: -- vision + - vision _target_: project.datamodules.ImageNetDataModule # todo: add good configuration options here. 
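A minimal usage sketch for the `imagenet` config above, assuming the config files are picked up from the `project.configs` module and that the top-level config is named `config` (both of those names are assumptions here; only the `datamodule=imagenet` group option comes from the yaml file itself):

    import hydra_zen
    from hydra import compose, initialize_config_module

    # Compose the top-level config with the new `datamodule=imagenet` option selected.
    # (config_module and config_name are assumptions, see the note above.)
    with initialize_config_module(config_module="project.configs", version_base=None):
        cfg = compose(config_name="config", overrides=["datamodule=imagenet"])

    # Instantiate the datamodule from its `_target_` and grab a first training batch.
    datamodule = hydra_zen.instantiate(cfg.datamodule)
    datamodule.prepare_data()
    datamodule.setup("fit")
    batch = next(iter(datamodule.train_dataloader()))

This mirrors the prepare/setup/dataloader sequence used by the `main()` benchmark function in `imagenet.py`.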
diff --git a/project/datamodules/image_classification/imagenet.py b/project/datamodules/image_classification/imagenet.py index ce1f76b9..8f194ba3 100644 --- a/project/datamodules/image_classification/imagenet.py +++ b/project/datamodules/image_classification/imagenet.py @@ -22,6 +22,7 @@ from torchvision.transforms import v2 as transform_lib from project.datamodules.vision.base import VisionDataModule +from project.utils.env_vars import NUM_WORKERS, DATA_DIR from project.utils.types import C, H, StageStr, W from project.utils.types.protocols import Module @@ -29,7 +30,9 @@ def imagenet_normalization(): - return transform_lib.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + return transform_lib.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) type ClassIndex = int @@ -64,11 +67,10 @@ class ImageNetDataModule(VisionDataModule): def __init__( self, - data_dir: str | Path | None = None, + data_dir: str | Path = DATA_DIR, *, - val_split: int - | float = 0.01, # save `val_split`% of the training data *of each class* for validation. - num_workers: int | None = None, + val_split: int | float = 0.01, + num_workers: int = NUM_WORKERS, normalize: bool = False, image_size: int = 224, batch_size: int = 32, @@ -138,7 +140,9 @@ def setup(self, stage: StageStr | None = None) -> None: logger.debug(f"Setup ImageNet datamodule for {stage=}") super().setup(stage) - def _split_dataset(self, dataset: ImageNet, train: bool = True) -> torch.utils.data.Dataset: + def _split_dataset( + self, dataset: ImageNet, train: bool = True + ) -> torch.utils.data.Dataset: class_item_indices: dict[ClassIndex, list[ImageIndex]] = defaultdict(list) for dataset_index, y in enumerate(dataset.targets): class_item_indices[y].append(dataset_index) @@ -261,6 +265,8 @@ def prepare_imagenet( train_archive_file_name = "ILSVRC2012_img_train.tar" devkit_file_name = "ILSVRC2012_devkit_t12.tar.gz" md5sums_file_name = "md5sums" + if not root.exists(): + root.mkdir(parents=True) def _symlink_if_needed(filename: str, network_imagenet_dir: Path): if not (symlink := root / filename).exists(): @@ -283,11 +289,14 @@ def _symlink_if_needed(filename: str, network_imagenet_dir: Path): train_dir = root / "train" train_dir.mkdir(exist_ok=True, parents=True) train_archive = network_imagenet_dir / train_archive_file_name + previously_extracted_dirs_file = train_dir / ".previously_extracted_dirs.txt" _extract_train_archive( train_archive=train_archive, train_dir=train_dir, - previously_extracted_dirs_file=root / "previously_extracted_dirs.txt", + previously_extracted_dirs_file=previously_extracted_dirs_file, ) + if previously_extracted_dirs_file.exists(): + previously_extracted_dirs_file.unlink() # OR: could just reuse the equivalent-ish from torchvision, but which doesn't support # resuming after an interrupt. @@ -336,17 +345,24 @@ def _extract_train_archive( *, train_archive: Path, train_dir: Path, previously_extracted_dirs_file: Path ) -> None: # The ImageNet train archive is a tarfile of tarfiles (one for each class). - logger.debug("Extracting the ImageNet train archive using Olexa's tar magic in python form...") + logger.debug( + "Extracting the ImageNet train archive using Olexa's tar magic in python form..." + ) + train_dir.mkdir(exist_ok=True, parents=True) # Save a small text file or something that tells us which subdirs are # done extracting so we can just skip ahead to the right directory? 
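+    # The resume file lists one fully-extracted class archive per line: those members are
+    # skipped on a re-run, while a class directory that exists without being listed is
+    # treated as partially extracted, deleted, and re-extracted below.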
previously_extracted_dirs: set[str] = set() + if previously_extracted_dirs_file.exists(): previously_extracted_dirs = set( stripped_line for line in previously_extracted_dirs_file.read_text().splitlines() if (stripped_line := line.strip()) ) + if len(previously_extracted_dirs) == 1000: + logger.info("Train archive already fully extracted. Skipping.") + return logger.debug( f"{len(previously_extracted_dirs)} directories have already been fully extracted." ) @@ -354,19 +370,17 @@ def _extract_train_archive( "\n".join(sorted(previously_extracted_dirs)) + "\n" ) - if len(previously_extracted_dirs) == 1000: + elif len(list(train_dir.iterdir())) == 1000: logger.info("Train archive already fully extracted. Skipping.") return with tarfile.open(train_archive, mode="r") as train_tarfile: - for class_id, member in enumerate( - tqdm.tqdm( - train_tarfile, - total=1000, # hard-coded here, since we know there are 1000 folders. - desc="Extracting train archive", - unit="Directories", - position=0, - ) + for member in tqdm.tqdm( + train_tarfile, + total=1000, # hard-coded here, since we know there are 1000 folders. + desc="Extracting train archive", + unit="Directories", + position=0, ): if member.name in previously_extracted_dirs: continue @@ -377,18 +391,14 @@ def _extract_train_archive( class_subdir = train_dir / member.name.replace(".tar", "") class_subdir_existed = class_subdir.exists() if class_subdir_existed: - files_in_subdir = set(p.name for p in class_subdir.iterdir()) + # Remove all the (potentially partially constructed) files in the directory. + logger.debug(f"Removing partially-constructed dir {class_subdir}") + shutil.rmtree(class_subdir, ignore_errors=False) else: class_subdir.mkdir(parents=True, exist_ok=True) - files_in_subdir = set() - - with tarfile.open(fileobj=buffer, mode="r|*") as sub_tarfile: - for tarinfo in sub_tarfile: - image_file_path = class_subdir / tarinfo.name - if files_in_subdir and image_file_path.name in files_in_subdir: - # Image file is already in the directory. - continue - sub_tarfile.extract(tarinfo, class_subdir) + + with tarfile.open(fileobj=buffer, mode="r|*") as class_tarfile: + class_tarfile.extractall(class_subdir, filter="data") # Alternative: .extractall with a list of members to extract: # members = sub_tarfile.getmembers() # note: loads the full archive. diff --git a/project/datamodules/vision/base.py b/project/datamodules/vision/base.py index b75d0a9d..74f9f0db 100644 --- a/project/datamodules/vision/base.py +++ b/project/datamodules/vision/base.py @@ -40,7 +40,7 @@ def __init__( self, data_dir: str | Path = DATA_DIR, val_split: int | float = 0.2, - num_workers: int | None = NUM_WORKERS, + num_workers: int = NUM_WORKERS, normalize: bool = False, batch_size: int = 32, seed: int = 42, @@ -100,9 +100,15 @@ def __init__( # todo: what about the shuffling at each epoch? 
_rng = torch.Generator(device="cpu").manual_seed(self.seed) - self.train_dl_rng_seed = int(torch.randint(0, int(1e6), (1,), generator=_rng).item()) - self.val_dl_rng_seed = int(torch.randint(0, int(1e6), (1,), generator=_rng).item()) - self.test_dl_rng_seed = int(torch.randint(0, int(1e6), (1,), generator=_rng).item()) + self.train_dl_rng_seed = int( + torch.randint(0, int(1e6), (1,), generator=_rng).item() + ) + self.val_dl_rng_seed = int( + torch.randint(0, int(1e6), (1,), generator=_rng).item() + ) + self.test_dl_rng_seed = int( + torch.randint(0, int(1e6), (1,), generator=_rng).item() + ) self.test_dataset_cls = self.dataset_cls @@ -150,7 +156,9 @@ def setup(self, stage: StageStr | None = None) -> None: if stage == "test" or stage is None: logger.debug(f"creating test dataset with kwargs {self.train_kwargs}") - self.dataset_test = self.test_dataset_cls(str(self.data_dir), **self.test_kwargs) + self.dataset_test = self.test_dataset_cls( + str(self.data_dir), **self.test_kwargs + ) def _split_dataset(self, dataset: VisionDataset, train: bool = True) -> Dataset: """Splits the dataset into train and validation set.""" @@ -182,7 +190,9 @@ def _get_splits(self, len_dataset: int) -> list[int]: def default_transforms(self) -> Callable: """Default transform for the dataset.""" - def train_dataloader[**P]( + def train_dataloader[ + **P + ]( self, _dataloader_fn: Callable[Concatenate[Dataset, P], DataLoader] = DataLoader, *args: P.args, @@ -202,7 +212,9 @@ def train_dataloader[**P]( ), ) - def val_dataloader[**P]( + def val_dataloader[ + **P + ]( self, _dataloader_fn: Callable[Concatenate[Dataset, P], DataLoader] = DataLoader, *args: P.args, @@ -214,10 +226,15 @@ def val_dataloader[**P]( self.dataset_val, _dataloader_fn=_dataloader_fn, *args, - **(dict(generator=torch.Generator().manual_seed(self.val_dl_rng_seed)) | kwargs), + **( + dict(generator=torch.Generator().manual_seed(self.val_dl_rng_seed)) + | kwargs + ), ) - def test_dataloader[**P]( + def test_dataloader[ + **P + ]( self, _dataloader_fn: Callable[Concatenate[Dataset, P], DataLoader] = DataLoader, *args: P.args, @@ -231,10 +248,15 @@ def test_dataloader[**P]( self.dataset_test, _dataloader_fn=_dataloader_fn, *args, - **(dict(generator=torch.Generator().manual_seed(self.test_dl_rng_seed)) | kwargs), + **( + dict(generator=torch.Generator().manual_seed(self.test_dl_rng_seed)) + | kwargs + ), ) - def _data_loader[**P]( + def _data_loader[ + **P + ]( self, dataset: Dataset, _dataloader_fn: Callable[Concatenate[Dataset, P], DataLoader] = DataLoader, @@ -247,7 +269,7 @@ def _data_loader[**P]( num_workers=self.num_workers, drop_last=self.drop_last, pin_memory=self.pin_memory, - persistent_workers=True if self.num_workers > 0 else False, + persistent_workers=(self.num_workers or 0) > 0, ) | dataloader_kwargs ) From aa1f6dba7352f71d2c46575461fef23423b7bd4a Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Tue, 18 Jun 2024 12:01:00 -0400 Subject: [PATCH 12/27] Minor touchup Signed-off-by: Fabrice Normandin --- project/datamodules/datamodules_test.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/project/datamodules/datamodules_test.py b/project/datamodules/datamodules_test.py index 5667990e..8737df7d 100644 --- a/project/datamodules/datamodules_test.py +++ b/project/datamodules/datamodules_test.py @@ -3,7 +3,10 @@ import matplotlib.pyplot as plt import pytest -from tensor_regression.fixture import TensorRegressionFixture, get_test_source_and_temp_file_paths +from tensor_regression.fixture import ( + 
TensorRegressionFixture, + get_test_source_and_temp_file_paths, +) from torch import Tensor from project.utils.testutils import run_for_all_datamodules @@ -68,7 +71,10 @@ def test_first_batch( fig.suptitle(f"First batch of datamodule {type(datamodule).__name__}") figure_path, _ = get_test_source_and_temp_file_paths( - extension=".png", request=request, original_datadir=original_datadir, datadir=datadir + extension=".png", + request=request, + original_datadir=original_datadir, + datadir=datadir, ) figure_path.parent.mkdir(exist_ok=True, parents=True) fig.savefig(figure_path) From dadd4673c115f12a5b3c656c70ec18e4d6bd2ece Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Tue, 18 Jun 2024 20:00:12 +0000 Subject: [PATCH 13/27] Touchups and import fixes Signed-off-by: Fabrice Normandin --- project/datamodules/__init__.py | 2 + .../image_classification/imagenet32.py | 9 ++- project/datamodules/vision/base.py | 64 +++++++------------ project/utils/testutils.py | 28 ++++---- 4 files changed, 43 insertions(+), 60 deletions(-) diff --git a/project/datamodules/__init__.py b/project/datamodules/__init__.py index 52e8d92c..1b0a558e 100644 --- a/project/datamodules/__init__.py +++ b/project/datamodules/__init__.py @@ -3,6 +3,7 @@ from .image_classification.fashion_mnist import FashionMNISTDataModule from .image_classification.imagenet import ImageNetDataModule from .image_classification.imagenet32 import ImageNet32DataModule, imagenet32_normalization +from .image_classification.inaturalist import INaturalistDataModule from .image_classification.mnist import MNISTDataModule from .vision.base import VisionDataModule @@ -10,6 +11,7 @@ "cifar10_normalization", "CIFAR10DataModule", "FashionMNISTDataModule", + "INaturalistDataModule", "ImageClassificationDataModule", "imagenet32_normalization", "ImageNet32DataModule", diff --git a/project/datamodules/image_classification/imagenet32.py b/project/datamodules/image_classification/imagenet32.py index 825ba493..795c6bd7 100644 --- a/project/datamodules/image_classification/imagenet32.py +++ b/project/datamodules/image_classification/imagenet32.py @@ -16,6 +16,7 @@ from torchvision.datasets import VisionDataset from torchvision.transforms import v2 as transforms +from project.utils.env_vars import SCRATCH from project.utils.types import C, H, StageStr, W from ..vision.base import VisionDataModule @@ -177,10 +178,10 @@ class ImageNet32DataModule(VisionDataModule): def __init__( self, data_dir: str | Path, - readonly_datasets_dir: str | Path | None = None, + readonly_datasets_dir: str | Path | None = SCRATCH, val_split: int | float = -1, num_images_per_val_class: int | None = 50, - num_workers: int | None = 0, + num_workers: int = 0, normalize: bool = False, batch_size: int = 32, seed: int = 42, @@ -221,7 +222,9 @@ def __init__( # ImageNetDataModule uses num_imgs_per_val_class: int = 50, which makes sense! Here # however we're using probably more than that for validation. 
- self.EXTRA_ARGS["readonly_datasets_dir"] = readonly_datasets_dir + self.train_kwargs["readonly_datasets_dir"] = readonly_datasets_dir + self.valid_kwargs["readonly_datasets_dir"] = readonly_datasets_dir + self.test_kwargs["readonly_datasets_dir"] = readonly_datasets_dir self.dataset_train: ImageNet32Dataset | Subset self.dataset_val: ImageNet32Dataset | Subset self.dataset_test: ImageNet32Dataset | Subset diff --git a/project/datamodules/vision/base.py b/project/datamodules/vision/base.py index 74f9f0db..a8f1e377 100644 --- a/project/datamodules/vision/base.py +++ b/project/datamodules/vision/base.py @@ -84,15 +84,9 @@ def __init__( self.test_transforms = test_transforms self.EXTRA_ARGS = kwargs - self.train_kwargs = self.EXTRA_ARGS | { - "transform": self.train_transforms or self.default_transforms() - } - self.valid_kwargs = self.EXTRA_ARGS | { - "transform": self.val_transforms or self.default_transforms() - } - self.test_kwargs = self.EXTRA_ARGS | { - "transform": self.test_transforms or self.default_transforms() - } + self.train_kwargs: dict = self.EXTRA_ARGS + self.valid_kwargs: dict = self.EXTRA_ARGS + self.test_kwargs: dict = self.EXTRA_ARGS if _has_constructor_argument(self.dataset_cls, "train"): self.train_kwargs["train"] = True self.valid_kwargs["train"] = True @@ -100,15 +94,9 @@ def __init__( # todo: what about the shuffling at each epoch? _rng = torch.Generator(device="cpu").manual_seed(self.seed) - self.train_dl_rng_seed = int( - torch.randint(0, int(1e6), (1,), generator=_rng).item() - ) - self.val_dl_rng_seed = int( - torch.randint(0, int(1e6), (1,), generator=_rng).item() - ) - self.test_dl_rng_seed = int( - torch.randint(0, int(1e6), (1,), generator=_rng).item() - ) + self.train_dl_rng_seed = int(torch.randint(0, int(1e6), (1,), generator=_rng).item()) + self.val_dl_rng_seed = int(torch.randint(0, int(1e6), (1,), generator=_rng).item()) + self.test_dl_rng_seed = int(torch.randint(0, int(1e6), (1,), generator=_rng).item()) self.test_dataset_cls = self.dataset_cls @@ -135,6 +123,16 @@ def prepare_data(self) -> None: ) self.test_dataset_cls(str(self.data_dir), **test_kwargs) + self.train_kwargs = self.EXTRA_ARGS | { + "transform": self.train_transforms or self.default_transforms() + } + self.valid_kwargs = self.EXTRA_ARGS | { + "transform": self.val_transforms or self.default_transforms() + } + self.test_kwargs = self.EXTRA_ARGS | { + "transform": self.test_transforms or self.default_transforms() + } + def setup(self, stage: StageStr | None = None) -> None: """Creates train, val, and test dataset.""" if stage in ["fit", "validate"] or stage is None: @@ -156,9 +154,7 @@ def setup(self, stage: StageStr | None = None) -> None: if stage == "test" or stage is None: logger.debug(f"creating test dataset with kwargs {self.train_kwargs}") - self.dataset_test = self.test_dataset_cls( - str(self.data_dir), **self.test_kwargs - ) + self.dataset_test = self.test_dataset_cls(str(self.data_dir), **self.test_kwargs) def _split_dataset(self, dataset: VisionDataset, train: bool = True) -> Dataset: """Splits the dataset into train and validation set.""" @@ -190,9 +186,7 @@ def _get_splits(self, len_dataset: int) -> list[int]: def default_transforms(self) -> Callable: """Default transform for the dataset.""" - def train_dataloader[ - **P - ]( + def train_dataloader[**P]( self, _dataloader_fn: Callable[Concatenate[Dataset, P], DataLoader] = DataLoader, *args: P.args, @@ -212,9 +206,7 @@ def train_dataloader[ ), ) - def val_dataloader[ - **P - ]( + def val_dataloader[**P]( self, 
_dataloader_fn: Callable[Concatenate[Dataset, P], DataLoader] = DataLoader, *args: P.args, @@ -226,15 +218,10 @@ def val_dataloader[ self.dataset_val, _dataloader_fn=_dataloader_fn, *args, - **( - dict(generator=torch.Generator().manual_seed(self.val_dl_rng_seed)) - | kwargs - ), + **(dict(generator=torch.Generator().manual_seed(self.val_dl_rng_seed)) | kwargs), ) - def test_dataloader[ - **P - ]( + def test_dataloader[**P]( self, _dataloader_fn: Callable[Concatenate[Dataset, P], DataLoader] = DataLoader, *args: P.args, @@ -248,15 +235,10 @@ def test_dataloader[ self.dataset_test, _dataloader_fn=_dataloader_fn, *args, - **( - dict(generator=torch.Generator().manual_seed(self.test_dl_rng_seed)) - | kwargs - ), + **(dict(generator=torch.Generator().manual_seed(self.test_dl_rng_seed)) | kwargs), ) - def _data_loader[ - **P - ]( + def _data_loader[**P]( self, dataset: Dataset, _dataloader_fn: Callable[Concatenate[Dataset, P], DataLoader] = DataLoader, diff --git a/project/utils/testutils.py b/project/utils/testutils.py index 73fb02b7..990d2cfa 100644 --- a/project/utils/testutils.py +++ b/project/utils/testutils.py @@ -24,8 +24,7 @@ from torch import Tensor, nn from torch.optim import Optimizer -from project.configs import Config, cs -from project.configs.datamodule import DATA_DIR +from project.configs import Config from project.datamodules.image_classification import ( ImageClassificationDataModule, ) @@ -47,7 +46,7 @@ "inaturalist": [ pytest.mark.slow, pytest.mark.xfail( - not Path("/network/datasets/inat").exists(), + not (NETWORK_DIR and (NETWORK_DIR / "datasets/inat").exists()), strict=True, raises=hydra.errors.InstantiationException, reason="Expects to be run on the Mila cluster for now", @@ -62,17 +61,6 @@ reason="Expects to be run on a cluster with the ImageNet dataset.", ), ], - "rl": [ - pytest.mark.xfail( - strict=False, - raises=AssertionError, - # match="Shapes are not the same." - reason="Isn't entirely deterministic yet.", - ), - ], - "moving_mnist": [ - (pytest.mark.slow if not (DATA_DIR / "MovingMNIST").exists() else pytest.mark.timeout(5)) - ], } """Dict with some default marks for some configs name.""" @@ -169,11 +157,15 @@ def get_all_algorithm_names() -> list[str]: return get_all_configs_in_group("algorithm") -def get_type_for_config_name(config_group: str, config_name: str, _cs: ConfigStore = cs) -> type: +def get_type_for_config_name( + config_group: str, config_name: str, _cs: ConfigStore | None = None +) -> type: """Returns the class that is to be instantiated by the given config name. In the case of inner dataclasses (e.g. Model.HParams), this returns the outer class (Model). """ + if _cs is None: + from project.configs import cs as _cs config_loader = get_config_loader() _, caching_repo = config_loader._parse_overrides_and_create_caching_repo( @@ -288,7 +280,11 @@ def test_network_output_is_reproducible(network: nn.Module, x: Tensor): def get_all_datamodule_names() -> list[str]: """Retrieves the names of all the datamodules that are saved in the ConfigStore of Hydra.""" - return get_all_configs_in_group("datamodule") + datamodules = get_all_configs_in_group("datamodule") + # todo: automatically detect which ones are configs for ABCs and remove them? 
+ if "vision" in datamodules: + datamodules.remove("vision") + return datamodules def get_all_datamodule_names_params(): From 0969afc361651d0e0eb9624443a083b447500a7f Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Wed, 19 Jun 2024 14:31:34 +0000 Subject: [PATCH 14/27] Switch to full config files for datamodule configs Signed-off-by: Fabrice Normandin --- project/algorithms/bases/algorithm_test.py | 2 - project/configs/config.py | 6 +++ project/configs/datamodule/__init__.py | 43 +++++++++++-------- project/configs/datamodule/inaturalist.yaml | 1 - project/configs/datamodule/vision.yaml | 9 ++++ project/conftest.py | 4 +- .../image_classification/inaturalist.py | 9 ++-- project/utils/env_vars.py | 4 ++ project/utils/testutils.py | 21 +++------ 9 files changed, 55 insertions(+), 44 deletions(-) create mode 100644 project/configs/datamodule/vision.yaml diff --git a/project/algorithms/bases/algorithm_test.py b/project/algorithms/bases/algorithm_test.py index 5df5ee95..e43a0e6a 100644 --- a/project/algorithms/bases/algorithm_test.py +++ b/project/algorithms/bases/algorithm_test.py @@ -297,11 +297,9 @@ def fork_rng(): def datamodule_name(self, request: pytest.FixtureRequest): """Fixture that gives the name of a datamodule to use.""" datamodule_name = request.param - if datamodule_name in default_marks_for_config_name: for marker in default_marks_for_config_name[datamodule_name]: request.applymarker(marker) - self._skip_if_unsupported("datamodule", datamodule_name, skip_or_xfail=SKIP_OR_XFAIL) return datamodule_name diff --git a/project/configs/config.py b/project/configs/config.py index 5ff41808..6a39e912 100644 --- a/project/configs/config.py +++ b/project/configs/config.py @@ -3,6 +3,12 @@ from logging import getLogger as get_logger from typing import Any, Literal +from omegaconf import OmegaConf + +from project.utils.env_vars import get_constant + +OmegaConf.register_new_resolver("constant", get_constant) + logger = get_logger(__name__) LogLevel = Literal["debug", "info", "warning", "error", "critical"] diff --git a/project/configs/datamodule/__init__.py b/project/configs/datamodule/__init__.py index 61bd5077..b8bd32a5 100644 --- a/project/configs/datamodule/__init__.py +++ b/project/configs/datamodule/__init__.py @@ -1,12 +1,9 @@ from logging import getLogger as get_logger from pathlib import Path -from hydra_zen import hydrated_dataclass, instantiate, store +from hydra_zen import store -from project.datamodules import ( - VisionDataModule, -) -from project.utils.env_vars import DATA_DIR, NETWORK_DIR, NUM_WORKERS +from project.utils.env_vars import NETWORK_DIR logger = get_logger(__name__) @@ -23,22 +20,30 @@ # TODO: Make it possible to extend a structured base via yaml files as well as adding new fields # (for example, ImagetNet32DataModule has a new constructor argument which can't be set atm in the # config). +datamodule_store = store(group="datamodule") -@hydrated_dataclass(target=VisionDataModule, populate_full_signature=True) -class VisionDataModuleConfig: - data_dir: str | None = str(torchvision_dir or DATA_DIR) - val_split: int | float = 0.1 # NOTE: reduced from default of 0.2 - num_workers: int = NUM_WORKERS - normalize: bool = True # NOTE: Set to True by default instead of False - batch_size: int = 32 - seed: int = 42 - shuffle: bool = True # NOTE: Set to True by default instead of False. - pin_memory: bool = True # NOTE: Set to True by default instead of False. 
- drop_last: bool = False +# @hydrated_dataclass(target=VisionDataModule, populate_full_signature=True) +# class VisionDataModuleConfig: +# data_dir: str | None = str(torchvision_dir or DATA_DIR) +# val_split: int | float = 0.1 # NOTE: reduced from default of 0.2 +# num_workers: int = NUM_WORKERS +# normalize: bool = True # NOTE: Set to True by default instead of False +# batch_size: int = 32 +# seed: int = 42 +# shuffle: bool = True # NOTE: Set to True by default instead of False. +# pin_memory: bool = True # NOTE: Set to True by default instead of False. +# drop_last: bool = False - __call__ = instantiate +# __call__ = instantiate -datamodule_store = store(group="datamodule") -datamodule_store(VisionDataModuleConfig, name="vision") +# datamodule_store(VisionDataModuleConfig, name="vision") + +# inaturalist_config = hydra_zen.builds( +# INaturalistDataModule, +# builds_bases=(VisionDataModuleConfig,), +# populate_full_signature=True, +# dataclass_name=f"{INaturalistDataModule.__name__}Config", +# ) +# datamodule_store(inaturalist_config, name="inaturalist") diff --git a/project/configs/datamodule/inaturalist.yaml b/project/configs/datamodule/inaturalist.yaml index 5cea87f5..5670be6b 100644 --- a/project/configs/datamodule/inaturalist.yaml +++ b/project/configs/datamodule/inaturalist.yaml @@ -1,6 +1,5 @@ defaults: - vision _target_: project.datamodules.INaturalistDataModule -data_dir: null version: "2021_train" target_type: "full" diff --git a/project/configs/datamodule/vision.yaml b/project/configs/datamodule/vision.yaml new file mode 100644 index 00000000..e3f10b79 --- /dev/null +++ b/project/configs/datamodule/vision.yaml @@ -0,0 +1,9 @@ +_target_: project.datamodules.VisionDataModule +data_dir: ${constant:DATA_DIR} +num_workers: ${constant:NUM_WORKERS} +val_split: 0.1 # NOTE: reduced from default of 0.2 +normalize: True # NOTE: Set to True by default instead of False +shuffle: True # NOTE: Set to True by default instead of False. +pin_memory: True # NOTE: Set to True by default instead of False. 
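+# NOTE: `${constant:...}` is resolved by the `constant` resolver registered in
+# `project/configs/config.py` via `OmegaConf.register_new_resolver("constant", get_constant)`;
+# it looks up the named module-level constant (like DATA_DIR above) from `project.utils.env_vars`.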
+seed: 42 +batch_size: 64 diff --git a/project/conftest.py b/project/conftest.py index 0a69188d..b5f968f1 100644 --- a/project/conftest.py +++ b/project/conftest.py @@ -21,7 +21,6 @@ from torch.utils.data import DataLoader from project.configs.config import Config -from project.configs.datamodule import DATA_DIR from project.datamodules.image_classification import ( ImageClassificationDataModule, ) @@ -35,6 +34,7 @@ setup_experiment, setup_logging, ) +from project.utils.env_vars import DATA_DIR from project.utils.hydra_utils import resolve_dictconfig from project.utils.testutils import default_marks_for_config_name from project.utils.types import is_sequence_of @@ -339,7 +339,7 @@ def algorithm_name(request: pytest.FixtureRequest) -> str | None: @pytest.fixture(scope="session") def datamodule_name(request: pytest.FixtureRequest) -> str | None: - datamodule_config_name = getattr(request, "param", None) + datamodule_config_name = getattr(request, "param") if datamodule_config_name: _add_default_marks_for_config_name(datamodule_config_name, request) return datamodule_config_name diff --git a/project/datamodules/image_classification/inaturalist.py b/project/datamodules/image_classification/inaturalist.py index 9bf192cd..a1090a72 100644 --- a/project/datamodules/image_classification/inaturalist.py +++ b/project/datamodules/image_classification/inaturalist.py @@ -10,7 +10,7 @@ from torchvision.datasets import INaturalist from project.datamodules.image_classification.base import ImageClassificationDataModule -from project.utils.env_vars import SLURM_TMPDIR +from project.utils.env_vars import DATA_DIR, NUM_WORKERS, SLURM_TMPDIR from project.utils.types import C, H, W logger = get_logger(__name__) @@ -44,9 +44,9 @@ class INaturalistDataModule(ImageClassificationDataModule): def __init__( self, - data_dir: str | Path | None = None, + data_dir: str | Path = DATA_DIR, val_split: int | float = 0.1, - num_workers: int | None = None, + num_workers: int = NUM_WORKERS, normalize: bool = False, batch_size: int = 32, seed: int = 42, @@ -105,7 +105,8 @@ def __init__( if not isinstance(target_type, list): self.num_classes = None - if version == "2021_train_mini" and target_type == "full": + # todo: double-check that the 2021_train split also has 10_000 classes. 
+ if version in ["2021_train_mini", "2021_train"] and target_type == "full": self.num_classes = 10_000 if isinstance(train_transforms, T.Compose): channels = 3 diff --git a/project/utils/env_vars.py b/project/utils/env_vars.py index 78ebdb38..cc9d663d 100644 --- a/project/utils/env_vars.py +++ b/project/utils/env_vars.py @@ -35,6 +35,10 @@ """Directory where datasets should be extracted.""" +def get_constant(name: str): + return globals()[name] + + NUM_WORKERS = int( os.environ.get( "SLURM_CPUS_PER_TASK", diff --git a/project/utils/testutils.py b/project/utils/testutils.py index 990d2cfa..2b50764a 100644 --- a/project/utils/testutils.py +++ b/project/utils/testutils.py @@ -30,7 +30,7 @@ ) from project.datamodules.vision.base import VisionDataModule from project.experiment import instantiate_trainer -from project.utils.env_vars import NETWORK_DIR, SLURM_JOB_ID +from project.utils.env_vars import NETWORK_DIR from project.utils.hydra_utils import get_attr, get_outer_class from project.utils.types import PhaseStr from project.utils.types.protocols import DataModule @@ -45,10 +45,10 @@ "imagenet32": [pytest.mark.slow], "inaturalist": [ pytest.mark.slow, - pytest.mark.xfail( + pytest.mark.skipif( not (NETWORK_DIR and (NETWORK_DIR / "datasets/inat").exists()), - strict=True, - raises=hydra.errors.InstantiationException, + # strict=True, + # raises=hydra.errors.InstantiationException, reason="Expects to be run on the Mila cluster for now", ), ], @@ -300,18 +300,7 @@ def get_all_datamodule_names_params(): marks=[ pytest.mark.xdist_group(name=dm_name), ] - + ([pytest.mark.slow] if dm_name in SLOW_DATAMODULES else []) - + ( - [ - pytest.mark.xfail( - SLURM_JOB_ID is None, - raises=NotImplementedError, - reason="Needs to be run on the Mila cluster atm.", - ) - ] - if dm_name == "inaturalist" - else [] - ), + + default_marks_for_config_name.get(dm_name, []), ) for dm_name in dm_names ] From 675a865d365a92ba868d290026a5775dccece5c3 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Wed, 19 Jun 2024 15:30:51 +0000 Subject: [PATCH 15/27] Add marks for combinations of configs Signed-off-by: Fabrice Normandin --- project/algorithms/bases/algorithm_test.py | 25 +++++++++++++++++++--- project/conftest.py | 15 ++++++++++++- project/utils/testutils.py | 23 +++++++++++++++++--- 3 files changed, 56 insertions(+), 7 deletions(-) diff --git a/project/algorithms/bases/algorithm_test.py b/project/algorithms/bases/algorithm_test.py index e43a0e6a..32583d5d 100644 --- a/project/algorithms/bases/algorithm_test.py +++ b/project/algorithms/bases/algorithm_test.py @@ -36,6 +36,7 @@ from project.main import main from project.utils.hydra_utils import resolve_dictconfig from project.utils.testutils import ( + default_marks_for_config_combinations, default_marks_for_config_name, get_all_datamodule_names_params, get_all_network_names, @@ -146,8 +147,11 @@ def get_testing_callbacks(self) -> list[TestingCallback]: AllParamsShouldHaveGradients(), ] + # todo: make this much faster to run! + # Also, some combinations don't work, e.g. `imagenet + fcnet`, there are nans in the network. + @pytest.mark.slow - @pytest.mark.timeout(10) # todo: make this much faster to run! 
+ # @pytest.mark.timeout(10) def test_overfit_training_batch( self, algorithm: AlgorithmType, @@ -323,7 +327,11 @@ def network_name(self, request: pytest.FixtureRequest): @pytest.fixture(scope="class") def _hydra_config( - self, datamodule_name: str, network_name: str, tmp_path_factory: pytest.TempPathFactory + self, + datamodule_name: str, + network_name: str, + tmp_path_factory: pytest.TempPathFactory, + request: pytest.FixtureRequest, ) -> DictConfig: """Fixture that gives the Hydra configuration for an experiment that uses this algorithm, datamodule, and network. @@ -336,6 +344,16 @@ def _hydra_config( # todo: Get the name of the algorithm from the hydra config? algorithm_name = self.algorithm_name + + combination = set([datamodule_name, network_name, algorithm_name]) + for configs, marks in default_marks_for_config_combinations.items(): + configs = set(configs) + if combination >= configs: + logger.debug(f"Applying markers because {combination} contains {configs}") + # There is a combination of potentially unsupported configs here. + for mark in marks: + request.applymarker(mark) + with setup_hydra_for_tests_and_compose( all_overrides=[ f"algorithm={algorithm_name}", @@ -652,7 +670,8 @@ def on_train_batch_end( parameters_with_nans = [ name for name, param in pl_module.named_parameters() if param.isnan().any() ] - assert not parameters_with_nans + if parameters_with_nans: + raise RuntimeError(f"Parameters {parameters_with_nans} contain NaNs!") parameters_with_nans_in_grad = [ name diff --git a/project/conftest.py b/project/conftest.py index b5f968f1..993a0f35 100644 --- a/project/conftest.py +++ b/project/conftest.py @@ -36,7 +36,10 @@ ) from project.utils.env_vars import DATA_DIR from project.utils.hydra_utils import resolve_dictconfig -from project.utils.testutils import default_marks_for_config_name +from project.utils.testutils import ( + default_marks_for_config_combinations, + default_marks_for_config_name, +) from project.utils.types import is_sequence_of from project.utils.types.protocols import DataModule @@ -362,9 +365,19 @@ def experiment_dictconfig( datamodule_name: str | None, network_name: str | None, overrides: tuple[str, ...], + request: pytest.FixtureRequest, ) -> Generator[DictConfig, None, None]: tmp_path = tmp_path_factory.mktemp("experiment_testing") + combination = set([datamodule_name, network_name, algorithm_name]) + for configs, marks in default_marks_for_config_combinations.items(): + configs = set(configs) + if combination >= configs: + logger.debug(f"Applying markers because {combination} contains {configs}") + # There is a combination of potentially unsupported configs here. + for mark in marks: + request.applymarker(mark) + default_overrides = [ # NOTE: if we were to run the test in a slurm job, this wouldn't make sense. 
"seed=42", diff --git a/project/utils/testutils.py b/project/utils/testutils.py index 2b50764a..4c11d0ff 100644 --- a/project/utils/testutils.py +++ b/project/utils/testutils.py @@ -36,10 +36,11 @@ from project.utils.types.protocols import DataModule from project.utils.utils import get_device +logger = get_logger(__name__) + on_github_ci = "GITHUB_ACTIONS" in os.environ on_self_hosted_github_ci = on_github_ci and "self-hosted" in os.environ.get("RUNNER_LABELS", "") -SLOW_DATAMODULES = ["inaturalist", "imagenet32", "imagenet"] default_marks_for_config_name: dict[str, list[pytest.MarkDecorator]] = { "imagenet32": [pytest.mark.slow], @@ -64,8 +65,24 @@ } """Dict with some default marks for some configs name.""" - -logger = get_logger(__name__) +default_marks_for_config_combinations: dict[tuple[str, ...], list[pytest.MarkDecorator]] = { + ("imagenet", "fcnet"): [ + pytest.mark.xfail( + reason="FcNet shouldn't be applied to the ImageNet datamodule. It can lead to nans in the parameters." + ) + ], + ("imagenet", "jax_fcnet"): [ + pytest.mark.xfail( + reason="FcNet shouldn't be applied to the ImageNet datamodule. It can lead to nans in the parameters." + ) + ], + ("imagenet", "jax_cnn"): [ + pytest.mark.xfail( + reason="todo: parameters contain nans when overfitting on one batch? Maybe we're " + "using too many iterations?" + ) + ], +} def parametrized_fixture(name: str, values: Sequence, ids=None, **kwargs): From ac5f6e1ee79bcc4f681d11bb7e261c56eae8ef0d Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Wed, 19 Jun 2024 15:55:21 +0000 Subject: [PATCH 16/27] Add regression file for ImageNet batch Signed-off-by: Fabrice Normandin --- .../test_first_batch/imagenet.yaml | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 project/datamodules/datamodules_test/test_first_batch/imagenet.yaml diff --git a/project/datamodules/datamodules_test/test_first_batch/imagenet.yaml b/project/datamodules/datamodules_test/test_first_batch/imagenet.yaml new file mode 100644 index 00000000..ca27ece1 --- /dev/null +++ b/project/datamodules/datamodules_test/test_first_batch/imagenet.yaml @@ -0,0 +1,21 @@ +'0': + device: cpu + hash: -5237812157965531059 + max: 2.64 + mean: -0.067 + min: -2.118 + shape: + - 32 + - 3 + - 224 + - 224 + sum: -320363.531 +'1': + device: cpu + hash: 2821605806225718212 + max: 982 + mean: 511.219 + min: 0 + shape: + - 32 + sum: 16359 From af79c0688cdd55da999edc7713de24dce8ea78dc Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Wed, 19 Jun 2024 15:55:56 +0000 Subject: [PATCH 17/27] Add pytest-testmon dev dependency Signed-off-by: Fabrice Normandin --- pdm.lock | 25 ++++++++++++++++++++----- pyproject.toml | 1 + 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/pdm.lock b/pdm.lock index 12a8b637..7edd8611 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.4.1" -content_hash = "sha256:805e3f5f1a98de3530f8ec547141537d7d30b7d7d7ca5a3b5f9477809327ecdd" +content_hash = "sha256:df6ae2c882d59bd85dcb8e915205e54db01eab36f4a4910af80acaa945817388" [[package]] name = "absl-py" @@ -2275,6 +2275,21 @@ files = [ {file = "pytest_skip_slow-0.0.5-py3-none-any.whl", hash = "sha256:e2f6401d6ed0db3be1402622a7b24f7df14f61ebd26feda808a0d45433d4d474"}, ] +[[package]] +name = "pytest-testmon" +version = "2.1.1" +requires_python = ">=3.8" +summary = "selects tests affected by changed files and methods" +groups = ["dev"] +dependencies = [ + "coverage<8,>=6", + "pytest<9,>=5", +] 
+files = [ + {file = "pytest-testmon-2.1.1.tar.gz", hash = "sha256:8ebe2c3de42d99306ee54cd4536fed0fc48346a954420da904b18e8d59b5da98"}, + {file = "pytest_testmon-2.1.1-py3-none-any.whl", hash = "sha256:8271ca47bc8c80760c4fc7fd7895ea786b111bbb31f13eeea879a6fd11fe2226"}, +] + [[package]] name = "pytest-timeout" version = "2.3.1" @@ -2838,15 +2853,15 @@ files = [ [[package]] name = "torch-jax-interop" version = "0.0.4.post7.dev0" -requires_python = "<4.0,>=3.11" +requires_python = ">=3.11,<4.0" git = "https://www.github.com/lebrice/torch_jax_interop" revision = "7f0c72fe19d8bd4bd957f20dd90d77acd8178bd4" summary = "Utility to convert Tensors from Jax to Torch and vice-versa" groups = ["default"] dependencies = [ - "flax<1.0.0,>=0.8.4", - "jax[cuda12]<1.0.0,>=0.4.28", - "pytorch2jax<1.0.0,>=0.1.0", + "flax<0.9.0,>=0.8.4", + "jax[cuda12]<0.5.0,>=0.4.28", + "pytorch2jax<0.2.0,>=0.1.0", "torch<3.0.0,>=2.3.0", ] diff --git a/pyproject.toml b/pyproject.toml index 15b6ba14..13b0a7d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ dev = [ "pytest-benchmark>=4.0.0", "pytest-cov>=5.0.0", "tensor-regression>=0.0.2.post3.dev0", + "pytest-testmon>=2.1.1", ] [[tool.pdm.source]] From bfd34cf7e2d61423b8c9d9850bbc32b1c580cf77 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Wed, 19 Jun 2024 16:12:09 +0000 Subject: [PATCH 18/27] Fix pre-commit issue Signed-off-by: Fabrice Normandin --- project/datamodules/image_classification/imagenet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/project/datamodules/image_classification/imagenet.py b/project/datamodules/image_classification/imagenet.py index 8f194ba3..f80356d5 100644 --- a/project/datamodules/image_classification/imagenet.py +++ b/project/datamodules/image_classification/imagenet.py @@ -22,7 +22,7 @@ from torchvision.transforms import v2 as transform_lib from project.datamodules.vision.base import VisionDataModule -from project.utils.env_vars import NUM_WORKERS, DATA_DIR +from project.utils.env_vars import DATA_DIR, NUM_WORKERS from project.utils.types import C, H, StageStr, W from project.utils.types.protocols import Module From f85ffb4338db3786b9bce1712a415d82e2a36dad Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Wed, 19 Jun 2024 18:29:38 +0000 Subject: [PATCH 19/27] Fix bug in VisionDataModule.train/test/val kwargs Signed-off-by: Fabrice Normandin --- project/datamodules/datamodules_test.py | 3 +- .../image_classification/imagenet32.py | 54 +++++++------------ .../image_classification/imagenet32_test.py | 9 ++-- project/datamodules/vision/base.py | 32 +++++------ 4 files changed, 39 insertions(+), 59 deletions(-) diff --git a/project/datamodules/datamodules_test.py b/project/datamodules/datamodules_test.py index 8737df7d..e0e9666c 100644 --- a/project/datamodules/datamodules_test.py +++ b/project/datamodules/datamodules_test.py @@ -15,7 +15,8 @@ from ..utils.types.protocols import DataModule -@pytest.mark.timeout(25, func_only=True) +# @pytest.mark.timeout(25, func_only=True) +@pytest.mark.slow @run_for_all_datamodules() def test_first_batch( datamodule: DataModule, diff --git a/project/datamodules/image_classification/imagenet32.py b/project/datamodules/image_classification/imagenet32.py index 795c6bd7..1aab318c 100644 --- a/project/datamodules/image_classification/imagenet32.py +++ b/project/datamodules/image_classification/imagenet32.py @@ -11,12 +11,13 @@ import gdown import numpy as np +import torch from PIL import Image from torch.utils.data import DataLoader, Dataset, Subset from 
torchvision.datasets import VisionDataset from torchvision.transforms import v2 as transforms -from project.utils.env_vars import SCRATCH +from project.utils.env_vars import DATA_DIR, SCRATCH from project.utils.types import C, H, StageStr, W from ..vision.base import VisionDataModule @@ -177,7 +178,7 @@ class ImageNet32DataModule(VisionDataModule): def __init__( self, - data_dir: str | Path, + data_dir: Path = DATA_DIR, readonly_datasets_dir: str | Path | None = SCRATCH, val_split: int | float = -1, num_images_per_val_class: int | None = 50, @@ -194,7 +195,7 @@ def __init__( ) -> None: Path(data_dir).mkdir(parents=True, exist_ok=True) super().__init__( - data_dir=str(data_dir), + data_dir=data_dir, val_split=val_split, num_workers=num_workers, normalize=normalize, @@ -208,7 +209,6 @@ def __init__( test_transforms=test_transforms, ) self.num_images_per_val_class = num_images_per_val_class - if self.val_split == -1 and self.num_images_per_val_class is None: raise ValueError( "Can't have both `val_split` and `num_images_per_val_class` set to `None`!" @@ -235,9 +235,7 @@ def num_samples(self) -> int: def prepare_data(self) -> None: """Saves files to data_dir.""" - # NOTE: In our case, the download gives us both. No need to do it twice. - self.dataset_cls(self.data_dir, train=True, download=True, **self.EXTRA_ARGS) - self.dataset_cls(self.data_dir, train=False, download=True, **self.EXTRA_ARGS) + super().prepare_data() def setup(self, stage: StageStr | None = None) -> None: """Creates train, val, and test dataset.""" @@ -249,32 +247,17 @@ def setup(self, stage: StageStr | None = None) -> None: else: logger.debug("Setting up for all stages") - if stage in ["fit", "val", None]: - train_transforms = ( - self.default_transforms() - if self.train_transforms is None - else self.train_transforms - ) - val_transforms = ( - self.default_transforms() if self.val_transforms is None else self.val_transforms - ) - # Create the entire dataset twice. This is only needed because they have different - # transforms... - base_dataset = self.dataset_cls( - self.data_dir, - train=True, - transform=transforms.ToTensor(), - **self.EXTRA_ARGS, - ) - # Make sure they both use the same underlying data. (so we don't use twice as much - # memory, like the base-class does! 
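A note on the `setup()` hunk just below: it keeps only one copy of the large ImageNet32 arrays in memory while still giving train and validation different transforms, by deep-copying the dataset object and then re-pointing its `data`/`targets` attributes back at the original arrays. A standalone sketch of the same pattern (`ArrayDataset` is a stand-in, not the project's `ImageNet32Dataset`):

import copy

import numpy as np


class ArrayDataset:
    """Minimal stand-in for a torchvision-style dataset with in-memory arrays."""

    def __init__(self, data: np.ndarray, targets: np.ndarray, transform=None):
        self.data, self.targets, self.transform = data, targets, transform

    def __getitem__(self, index: int):
        x = self.data[index]
        return (self.transform(x) if self.transform else x), self.targets[index]

    def __len__(self) -> int:
        return len(self.data)


base = ArrayDataset(np.zeros((100, 32, 32, 3)), np.zeros(100))
train_view = copy.deepcopy(base)
train_view.transform = lambda x: x * 2  # train-specific transform would go here
# Share the underlying storage again instead of keeping the deep copies:
train_view.data, train_view.targets = base.data, base.targets
assert train_view.data is base.data  # no duplicated arrays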
+ if stage in ["fit", "validate", None]: + base_dataset = self.dataset_cls(self.data_dir, **self.train_kwargs) + assert len(base_dataset) == 1_281_159 + base_dataset_train = copy.deepcopy(base_dataset) - base_dataset_train.transform = train_transforms + base_dataset_train.transform = self.train_transforms base_dataset_train.data = base_dataset.data base_dataset_train.targets = base_dataset.targets base_dataset_valid = copy.deepcopy(base_dataset) - base_dataset_valid.transform = val_transforms + base_dataset_valid.transform = self.val_transforms base_dataset_valid.data = base_dataset.data base_dataset_valid.targets = base_dataset.targets @@ -300,13 +283,13 @@ def setup(self, stage: StageStr | None = None) -> None: def default_transforms(self) -> Callable: """Default transform for the dataset.""" - if self.normalize: - in32_transforms = transforms.Compose( - [transforms.ToTensor(), imagenet32_normalization()] - ) - else: - in32_transforms = transforms.Compose([transforms.ToTensor()]) - return in32_transforms + return transforms.Compose( + [ + transforms.ToImage(), + transforms.ToDtype(torch.float32, scale=True), + ] + + ([imagenet32_normalization()] if self.normalize else []) + ) def train_dataloader(self) -> DataLoader: """The train dataloader.""" @@ -331,6 +314,7 @@ def _data_loader(self, dataset: Dataset, shuffle: bool = False) -> DataLoader: ) def _split_dataset(self, dataset: ImageNet32Dataset, train: bool = True) -> Subset: + assert self.val_split >= 0 split_dataset = super()._split_dataset(dataset, train=train) assert isinstance(split_dataset, Subset) return split_dataset diff --git a/project/datamodules/image_classification/imagenet32_test.py b/project/datamodules/image_classification/imagenet32_test.py index 0478fbfe..d4913da7 100644 --- a/project/datamodules/image_classification/imagenet32_test.py +++ b/project/datamodules/image_classification/imagenet32_test.py @@ -1,18 +1,17 @@ import itertools -from pathlib import Path import pytest -from project.utils.env_vars import SCRATCH +from project.utils.env_vars import DATA_DIR, SCRATCH from .imagenet32 import ImageNet32DataModule @pytest.mark.slow -def test_dataset_download_works(data_dir: Path): +def test_dataset_download_works(): batch_size = 16 datamodule = ImageNet32DataModule( - data_dir=data_dir, + data_dir=DATA_DIR, readonly_datasets_dir=SCRATCH, batch_size=batch_size, num_images_per_val_class=10, @@ -22,7 +21,7 @@ def test_dataset_download_works(data_dir: Path): datamodule.prepare_data() datamodule.setup(None) - expected_total = 1281159 + expected_total = 1_281_159 assert ( datamodule.num_samples == expected_total - datamodule.num_classes * datamodule.num_images_per_val_class diff --git a/project/datamodules/vision/base.py b/project/datamodules/vision/base.py index a8f1e377..c8419ae6 100644 --- a/project/datamodules/vision/base.py +++ b/project/datamodules/vision/base.py @@ -84,14 +84,6 @@ def __init__( self.test_transforms = test_transforms self.EXTRA_ARGS = kwargs - self.train_kwargs: dict = self.EXTRA_ARGS - self.valid_kwargs: dict = self.EXTRA_ARGS - self.test_kwargs: dict = self.EXTRA_ARGS - if _has_constructor_argument(self.dataset_cls, "train"): - self.train_kwargs["train"] = True - self.valid_kwargs["train"] = True - self.test_kwargs["train"] = False - # todo: what about the shuffling at each epoch? 
_rng = torch.Generator(device="cpu").manual_seed(self.seed) self.train_dl_rng_seed = int(torch.randint(0, int(1e6), (1,), generator=_rng).item()) @@ -104,6 +96,20 @@ def __init__( self.dataset_val: Dataset | None = None self.dataset_test: VisionDataset | None = None + self.train_kwargs = self.EXTRA_ARGS | { + "transform": self.train_transforms or self.default_transforms() + } + self.valid_kwargs = self.EXTRA_ARGS | { + "transform": self.val_transforms or self.default_transforms() + } + self.test_kwargs = self.EXTRA_ARGS | { + "transform": self.test_transforms or self.default_transforms() + } + if _has_constructor_argument(self.dataset_cls, "train"): + self.train_kwargs["train"] = True + self.valid_kwargs["train"] = True + self.test_kwargs["train"] = False + def prepare_data(self) -> None: """Saves files to data_dir.""" # Call with `train=True` and `train=False` if there is such an argument. @@ -123,16 +129,6 @@ def prepare_data(self) -> None: ) self.test_dataset_cls(str(self.data_dir), **test_kwargs) - self.train_kwargs = self.EXTRA_ARGS | { - "transform": self.train_transforms or self.default_transforms() - } - self.valid_kwargs = self.EXTRA_ARGS | { - "transform": self.val_transforms or self.default_transforms() - } - self.test_kwargs = self.EXTRA_ARGS | { - "transform": self.test_transforms or self.default_transforms() - } - def setup(self, stage: StageStr | None = None) -> None: """Creates train, val, and test dataset.""" if stage in ["fit", "validate"] or stage is None: From 247b87372d9a6ece2c36dfddc264ee659592d296 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Wed, 19 Jun 2024 18:54:23 +0000 Subject: [PATCH 20/27] Fix issues with imagenet32+VisionDataModule Signed-off-by: Fabrice Normandin --- project/configs/datamodule/imagenet32.yaml | 6 ++--- .../test_first_batch/cifar10.yaml | 2 +- .../test_first_batch/imagenet.yaml | 18 ++++++------- .../image_classification/imagenet32.py | 12 +++------ .../image_classification/imagenet32_test.py | 2 +- project/datamodules/vision/base.py | 26 +++++++++---------- project/utils/testutils.py | 1 + 7 files changed, 31 insertions(+), 36 deletions(-) diff --git a/project/configs/datamodule/imagenet32.yaml b/project/configs/datamodule/imagenet32.yaml index 9c4bb7d7..208119be 100644 --- a/project/configs/datamodule/imagenet32.yaml +++ b/project/configs/datamodule/imagenet32.yaml @@ -1,11 +1,9 @@ defaults: - vision _target_: project.datamodules.ImageNet32DataModule -data_dir: "${oc.env:SCRATCH}/data" +data_dir: ${constant:SCRATCH} val_split: -1 -# TODO: Can't currently add this key since it isn't in the structured config for the -# `VisionDataModule`. 
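About the `${constant:SCRATCH}` interpolation a few lines up: it presumably works because `get_constant` (added to `project/utils/env_vars.py` earlier in this series) gets registered as a custom OmegaConf resolver named `constant`. The registration itself is not shown in these patches, so the wiring below is an assumption, not the project's actual code:

from pathlib import Path

from omegaconf import OmegaConf

SCRATCH = Path("/tmp/scratch")  # placeholder for project.utils.env_vars.SCRATCH


def get_constant(name: str):
    return globals()[name]


# Assumed registration; this call is not part of the patch series.
OmegaConf.register_new_resolver("constant", get_constant)

cfg = OmegaConf.create({"data_dir": "${constant:SCRATCH}"})
print(cfg.data_dir)  # -> /tmp/scratch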
-# num_images_per_val_class: 50 +num_images_per_val_class: 50 normalize: True train_transforms: _target_: project.datamodules.image_classification.imagenet32.imagenet32_train_transforms diff --git a/project/datamodules/datamodules_test/test_first_batch/cifar10.yaml b/project/datamodules/datamodules_test/test_first_batch/cifar10.yaml index e027265f..f798ebc1 100644 --- a/project/datamodules/datamodules_test/test_first_batch/cifar10.yaml +++ b/project/datamodules/datamodules_test/test_first_batch/cifar10.yaml @@ -9,7 +9,7 @@ - 3 - 32 - 32 - sum: -2919.015 + sum: -2919.016 '1': device: cpu hash: 3692171093056153318 diff --git a/project/datamodules/datamodules_test/test_first_batch/imagenet.yaml b/project/datamodules/datamodules_test/test_first_batch/imagenet.yaml index ca27ece1..6c2baa19 100644 --- a/project/datamodules/datamodules_test/test_first_batch/imagenet.yaml +++ b/project/datamodules/datamodules_test/test_first_batch/imagenet.yaml @@ -1,21 +1,21 @@ '0': device: cpu - hash: -5237812157965531059 + hash: 3674008927974037273 max: 2.64 - mean: -0.067 + mean: -0.084 min: -2.118 shape: - - 32 + - 64 - 3 - 224 - 224 - sum: -320363.531 + sum: -809988.0 '1': device: cpu - hash: 2821605806225718212 - max: 982 - mean: 511.219 + hash: 3360823606619711831 + max: 988 + mean: 518.219 min: 0 shape: - - 32 - sum: 16359 + - 64 + sum: 33166 diff --git a/project/datamodules/image_classification/imagenet32.py b/project/datamodules/image_classification/imagenet32.py index 1aab318c..0ab5a744 100644 --- a/project/datamodules/image_classification/imagenet32.py +++ b/project/datamodules/image_classification/imagenet32.py @@ -207,6 +207,8 @@ def __init__( train_transforms=train_transforms, val_transforms=val_transforms, test_transforms=test_transforms, + # extra kwargs + readonly_datasets_dir=readonly_datasets_dir, ) self.num_images_per_val_class = num_images_per_val_class if self.val_split == -1 and self.num_images_per_val_class is None: @@ -220,11 +222,6 @@ def __init__( ) self.num_images_per_val_class = None - # ImageNetDataModule uses num_imgs_per_val_class: int = 50, which makes sense! Here - # however we're using probably more than that for validation. 
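For context on these `test_first_batch/*.yaml` files: each numbered entry stores cheap summary statistics of one tensor in the first batch, so regressions show up as diffs in these values. Something along these lines could compute the numeric fields; the `hash` field comes from the tensor-regression package, whose hashing scheme is an implementation detail omitted here:

import torch


def batch_entry(t: torch.Tensor) -> dict:
    # Mirrors the numeric fields stored in the regression files.
    f = t.float()
    return {
        "device": str(t.device),
        "max": round(t.max().item(), 3),
        "mean": round(f.mean().item(), 3),
        "min": round(t.min().item(), 3),
        "shape": list(t.shape),
        "sum": round(f.sum().item(), 3),
    }


x = torch.randn(64, 3, 224, 224)  # images
y = torch.randint(0, 1000, (64,))  # labels
print({"0": batch_entry(x), "1": batch_entry(y)})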
- self.train_kwargs["readonly_datasets_dir"] = readonly_datasets_dir - self.valid_kwargs["readonly_datasets_dir"] = readonly_datasets_dir - self.test_kwargs["readonly_datasets_dir"] = readonly_datasets_dir self.dataset_train: ImageNet32Dataset | Subset self.dataset_val: ImageNet32Dataset | Subset self.dataset_test: ImageNet32Dataset | Subset @@ -274,9 +271,7 @@ def setup(self, stage: StageStr | None = None) -> None: self.dataset_val = self._split_dataset(base_dataset_valid, train=False) if stage in ["test", None]: - test_transforms = ( - self.default_transforms() if self.test_transforms is None else self.test_transforms - ) + test_transforms = self.test_transforms or self.default_transforms() self.dataset_test = self.dataset_cls( self.data_dir, train=False, transform=test_transforms, **self.EXTRA_ARGS ) @@ -350,6 +345,7 @@ def imagenet32_train_transforms(): return transforms.Compose( [ transforms.ToImage(), + transforms.ToDtype(torch.float32, scale=True), transforms.RandomHorizontalFlip(p=0.5), transforms.RandomCrop(size=32, padding=4, padding_mode="edge"), imagenet32_normalization(), diff --git a/project/datamodules/image_classification/imagenet32_test.py b/project/datamodules/image_classification/imagenet32_test.py index d4913da7..96696049 100644 --- a/project/datamodules/image_classification/imagenet32_test.py +++ b/project/datamodules/image_classification/imagenet32_test.py @@ -20,8 +20,8 @@ def test_dataset_download_works(): assert datamodule.val_split == -1 datamodule.prepare_data() datamodule.setup(None) - expected_total = 1_281_159 + assert ( datamodule.num_samples == expected_total - datamodule.num_classes * datamodule.num_images_per_val_class diff --git a/project/datamodules/vision/base.py b/project/datamodules/vision/base.py index c8419ae6..b3fa7542 100644 --- a/project/datamodules/vision/base.py +++ b/project/datamodules/vision/base.py @@ -12,6 +12,7 @@ from lightning import LightningDataModule from torch.utils.data import DataLoader, Dataset, random_split from torchvision.datasets import VisionDataset +from torchvision.transforms import v2 as transforms from project.utils.env_vars import DATA_DIR, NUM_WORKERS from project.utils.types import C, H, StageStr, W @@ -79,10 +80,13 @@ def __init__( self.shuffle = shuffle self.pin_memory = pin_memory self.drop_last = drop_last - self.train_transforms = train_transforms - self.val_transforms = val_transforms - self.test_transforms = test_transforms - self.EXTRA_ARGS = kwargs + self.train_transforms = train_transforms or self.default_transforms() + self.val_transforms = val_transforms or transforms.Compose( + [transforms.ToImage(), transforms.ToDtype(torch.float32, scale=True)] + ) + self.test_transforms = test_transforms or transforms.Compose( + [transforms.ToImage(), transforms.ToDtype(torch.float32, scale=True)] + ) # todo: what about the shuffling at each epoch? 
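On the recurring `ToImage()` + `ToDtype(torch.float32, scale=True)` pairs in these hunks: in torchvision's v2 transforms API, that pair is the standard replacement for the deprecated `ToTensor()`, producing a float image scaled to [0, 1] before any normalization. A small sketch of the kind of default pipeline these datamodules build (the mean/std shown are the usual ImageNet constants, matching `imagenet_normalization()` elsewhere in this repo):

import torch
from torchvision.transforms import v2 as transforms


def default_transforms(normalize: bool = False) -> transforms.Compose:
    return transforms.Compose(
        [
            transforms.ToImage(),
            transforms.ToDtype(torch.float32, scale=True),
        ]
        + (
            [transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
            if normalize
            else []
        )
    )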
_rng = torch.Generator(device="cpu").manual_seed(self.seed) @@ -96,15 +100,11 @@ def __init__( self.dataset_val: Dataset | None = None self.dataset_test: VisionDataset | None = None - self.train_kwargs = self.EXTRA_ARGS | { - "transform": self.train_transforms or self.default_transforms() - } - self.valid_kwargs = self.EXTRA_ARGS | { - "transform": self.val_transforms or self.default_transforms() - } - self.test_kwargs = self.EXTRA_ARGS | { - "transform": self.test_transforms or self.default_transforms() - } + self.EXTRA_ARGS = kwargs + self.train_kwargs = self.EXTRA_ARGS | {"transform": self.train_transforms} + self.valid_kwargs = self.EXTRA_ARGS | {"transform": self.val_transforms} + self.test_kwargs = self.EXTRA_ARGS | {"transform": self.test_transforms} + if _has_constructor_argument(self.dataset_cls, "train"): self.train_kwargs["train"] = True self.valid_kwargs["train"] = True diff --git a/project/utils/testutils.py b/project/utils/testutils.py index 4c11d0ff..1587b8d5 100644 --- a/project/utils/testutils.py +++ b/project/utils/testutils.py @@ -62,6 +62,7 @@ reason="Expects to be run on a cluster with the ImageNet dataset.", ), ], + "vision": [pytest.mark.skip(reason="Base class, shouldn't be instantiated.")], } """Dict with some default marks for some configs name.""" From cadd54f9491f5e44a19265e6cbf989bcd08db449 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Wed, 19 Jun 2024 18:58:37 +0000 Subject: [PATCH 21/27] Add a regression file for first batch ImageNet32 Signed-off-by: Fabrice Normandin --- .../test_first_batch/imagenet32.yaml | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 project/datamodules/datamodules_test/test_first_batch/imagenet32.yaml diff --git a/project/datamodules/datamodules_test/test_first_batch/imagenet32.yaml b/project/datamodules/datamodules_test/test_first_batch/imagenet32.yaml new file mode 100644 index 00000000..2540dc74 --- /dev/null +++ b/project/datamodules/datamodules_test/test_first_batch/imagenet32.yaml @@ -0,0 +1,21 @@ +'0': + device: cpu + hash: -8533209956811673698 + max: 2.64 + mean: 0.014 + min: -2.118 + shape: + - 64 + - 3 + - 32 + - 32 + sum: 2763.33 +'1': + device: cpu + hash: -8357971836707848708 + max: 993 + mean: 487.125 + min: 1 + shape: + - 64 + sum: 31176 From 2a3d5bba2258f3466fc7c5b6fea13fb14bd75aff Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Thu, 20 Jun 2024 13:31:40 +0000 Subject: [PATCH 22/27] Move VisionDataModule, fix import issues Signed-off-by: Fabrice Normandin --- project/algorithms/bases/algorithm_test.py | 15 ++------------- project/conftest.py | 10 ++-------- project/datamodules/__init__.py | 2 +- .../datamodules/image_classification/base.py | 2 +- .../image_classification/imagenet.py | 14 ++++---------- .../image_classification/imagenet32.py | 3 +-- .../datamodules/{vision/base.py => vision.py} | 3 +-- project/datamodules/vision/__init__.py | 0 project/utils/testutils.py | 18 +++++++++++++++--- 9 files changed, 27 insertions(+), 40 deletions(-) rename project/datamodules/{vision/base.py => vision.py} (99%) delete mode 100644 project/datamodules/vision/__init__.py diff --git a/project/algorithms/bases/algorithm_test.py b/project/algorithms/bases/algorithm_test.py index 32583d5d..8d9a69ef 100644 --- a/project/algorithms/bases/algorithm_test.py +++ b/project/algorithms/bases/algorithm_test.py @@ -1,10 +1,8 @@ from __future__ import annotations -import contextlib import copy import inspect import operator -import random import sys import typing from collections.abc import Callable, 
Sequence @@ -12,7 +10,6 @@ from pathlib import Path from typing import Any, ClassVar, Generic, Literal, TypeVar -import numpy as np import pytest import torch from lightning import Callback, LightningDataModule, LightningModule, Trainer @@ -28,7 +25,7 @@ from project.datamodules.image_classification import ( ImageClassificationDataModule, ) -from project.datamodules.vision.base import VisionDataModule +from project.datamodules.vision import VisionDataModule from project.experiment import ( instantiate_datamodule, instantiate_network, @@ -38,6 +35,7 @@ from project.utils.testutils import ( default_marks_for_config_combinations, default_marks_for_config_name, + fork_rng, get_all_datamodule_names_params, get_all_network_names, get_type_for_config_name, @@ -268,15 +266,6 @@ def test_experiment_reproducible_given_seed( overrides_1 = all_overrides + [f"++trainer.default_root_dir={tmp_path_1}"] overrides_2 = all_overrides + [f"++trainer.default_root_dir={tmp_path_2}"] - @contextlib.contextmanager - def fork_rng(): - with torch.random.fork_rng(): - random_state = random.getstate() - np_random_state = np.random.get_state() - yield - np.random.set_state(np_random_state) - random.setstate(random_state) - with ( fork_rng(), setup_hydra_for_tests_and_compose(overrides_1, tmp_path=tmp_path_1) as config_1, diff --git a/project/conftest.py b/project/conftest.py index 993a0f35..5cf5a914 100644 --- a/project/conftest.py +++ b/project/conftest.py @@ -24,7 +24,7 @@ from project.datamodules.image_classification import ( ImageClassificationDataModule, ) -from project.datamodules.vision.base import VisionDataModule +from project.datamodules.vision import VisionDataModule from project.experiment import ( instantiate_algorithm, instantiate_datamodule, @@ -34,7 +34,6 @@ setup_experiment, setup_logging, ) -from project.utils.env_vars import DATA_DIR from project.utils.hydra_utils import resolve_dictconfig from project.utils.testutils import ( default_marks_for_config_combinations, @@ -138,11 +137,6 @@ def pytest_collection_modifyitems(config: pytest.Config, items: list[Function]): items.pop(index) -@pytest.fixture(scope="session") -def data_dir() -> Path: - return DATA_DIR - - @pytest.fixture(autouse=True) def seed(request: pytest.FixtureRequest): """Fixture that seeds everything for reproducibility and yields the random seed used.""" @@ -342,7 +336,7 @@ def algorithm_name(request: pytest.FixtureRequest) -> str | None: @pytest.fixture(scope="session") def datamodule_name(request: pytest.FixtureRequest) -> str | None: - datamodule_config_name = getattr(request, "param") + datamodule_config_name = getattr(request, "param", None) if datamodule_config_name: _add_default_marks_for_config_name(datamodule_config_name, request) return datamodule_config_name diff --git a/project/datamodules/__init__.py b/project/datamodules/__init__.py index 1b0a558e..40eb5928 100644 --- a/project/datamodules/__init__.py +++ b/project/datamodules/__init__.py @@ -5,7 +5,7 @@ from .image_classification.imagenet32 import ImageNet32DataModule, imagenet32_normalization from .image_classification.inaturalist import INaturalistDataModule from .image_classification.mnist import MNISTDataModule -from .vision.base import VisionDataModule +from .vision import VisionDataModule __all__ = [ "cifar10_normalization", diff --git a/project/datamodules/image_classification/base.py b/project/datamodules/image_classification/base.py index 8b392437..331cfbe6 100644 --- a/project/datamodules/image_classification/base.py +++ 
b/project/datamodules/image_classification/base.py @@ -2,7 +2,7 @@ from torch import Tensor -from project.datamodules.vision.base import VisionDataModule +from project.datamodules.vision import VisionDataModule from project.utils.types import C, H, W # todo: decide if this should be a protocol or an actual base class (currently a base class). diff --git a/project/datamodules/image_classification/imagenet.py b/project/datamodules/image_classification/imagenet.py index f80356d5..60554038 100644 --- a/project/datamodules/image_classification/imagenet.py +++ b/project/datamodules/image_classification/imagenet.py @@ -21,7 +21,7 @@ from torchvision.models.resnet import ResNet152_Weights from torchvision.transforms import v2 as transform_lib -from project.datamodules.vision.base import VisionDataModule +from project.datamodules.vision import VisionDataModule from project.utils.env_vars import DATA_DIR, NUM_WORKERS from project.utils.types import C, H, StageStr, W from project.utils.types.protocols import Module @@ -30,9 +30,7 @@ def imagenet_normalization(): - return transform_lib.Normalize( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] - ) + return transform_lib.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) type ClassIndex = int @@ -140,9 +138,7 @@ def setup(self, stage: StageStr | None = None) -> None: logger.debug(f"Setup ImageNet datamodule for {stage=}") super().setup(stage) - def _split_dataset( - self, dataset: ImageNet, train: bool = True - ) -> torch.utils.data.Dataset: + def _split_dataset(self, dataset: ImageNet, train: bool = True) -> torch.utils.data.Dataset: class_item_indices: dict[ClassIndex, list[ImageIndex]] = defaultdict(list) for dataset_index, y in enumerate(dataset.targets): class_item_indices[y].append(dataset_index) @@ -345,9 +341,7 @@ def _extract_train_archive( *, train_archive: Path, train_dir: Path, previously_extracted_dirs_file: Path ) -> None: # The ImageNet train archive is a tarfile of tarfiles (one for each class). - logger.debug( - "Extracting the ImageNet train archive using Olexa's tar magic in python form..." 
- ) + logger.debug("Extracting the ImageNet train archive using Olexa's tar magic in python form...") train_dir.mkdir(exist_ok=True, parents=True) # Save a small text file or something that tells us which subdirs are diff --git a/project/datamodules/image_classification/imagenet32.py b/project/datamodules/image_classification/imagenet32.py index 0ab5a744..c66eb0f3 100644 --- a/project/datamodules/image_classification/imagenet32.py +++ b/project/datamodules/image_classification/imagenet32.py @@ -17,11 +17,10 @@ from torchvision.datasets import VisionDataset from torchvision.transforms import v2 as transforms +from project.datamodules.vision import VisionDataModule from project.utils.env_vars import DATA_DIR, SCRATCH from project.utils.types import C, H, StageStr, W -from ..vision.base import VisionDataModule - logger = getLogger(__name__) diff --git a/project/datamodules/vision/base.py b/project/datamodules/vision.py similarity index 99% rename from project/datamodules/vision/base.py rename to project/datamodules/vision.py index b3fa7542..c2d0f6b1 100644 --- a/project/datamodules/vision/base.py +++ b/project/datamodules/vision.py @@ -16,8 +16,7 @@ from project.utils.env_vars import DATA_DIR, NUM_WORKERS from project.utils.types import C, H, StageStr, W - -from ...utils.types.protocols import DataModule +from project.utils.types.protocols import DataModule logger = get_logger(__name__) diff --git a/project/datamodules/vision/__init__.py b/project/datamodules/vision/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/project/utils/testutils.py b/project/utils/testutils.py index 1587b8d5..617c999e 100644 --- a/project/utils/testutils.py +++ b/project/utils/testutils.py @@ -2,11 +2,13 @@ from __future__ import annotations +import contextlib import copy import dataclasses import hashlib import importlib import os +import random from collections.abc import Mapping, Sequence from contextlib import contextmanager from logging import getLogger as get_logger @@ -15,6 +17,7 @@ import hydra.errors import hydra_zen +import numpy as np import pytest import torch import yaml @@ -28,7 +31,7 @@ from project.datamodules.image_classification import ( ImageClassificationDataModule, ) -from project.datamodules.vision.base import VisionDataModule +from project.datamodules.vision import VisionDataModule from project.experiment import instantiate_trainer from project.utils.env_vars import NETWORK_DIR from project.utils.hydra_utils import get_attr, get_outer_class @@ -620,8 +623,17 @@ def assert_no_nans_in_params_or_grads(module: nn.Module): assert not torch.isnan(param.grad).any(), name +@contextlib.contextmanager +def fork_rng(): + with torch.random.fork_rng(): + random_state = random.getstate() + np_random_state = np.random.get_state() + yield + np.random.set_state(np_random_state) + random.setstate(random_state) + + @contextmanager def seeded(seed: int = 42): - with torch.random.fork_rng(): - torch.random.manual_seed(seed) + with fork_rng(): yield From 7480ac9b169f9d5ddc4f5e4dbc08a33c55d910e3 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Thu, 20 Jun 2024 13:39:20 +0000 Subject: [PATCH 23/27] Update dependencies Signed-off-by: Fabrice Normandin --- pdm.lock | 91 ++++++++++++++++++++++++++++---------------------------- 1 file changed, 46 insertions(+), 45 deletions(-) diff --git a/pdm.lock b/pdm.lock index 7edd8611..80be51c6 100644 --- a/pdm.lock +++ b/pdm.lock @@ -200,7 +200,7 @@ files = [ [[package]] name = "brax" -version = "0.10.4" +version = "0.10.5" summary = "A 
differentiable physics engine written in JAX." groups = ["default"] dependencies = [ @@ -222,6 +222,7 @@ dependencies = [ "mujoco-mjx", "numpy", "optax", + "orbax-checkpoint", "pytinyrenderer", "scipy", "tensorboardX", @@ -229,8 +230,8 @@ dependencies = [ "typing-extensions", ] files = [ - {file = "brax-0.10.4-py3-none-any.whl", hash = "sha256:c47affa423ed0b2a987baef2553eeb84e701d52bfaa72695421d8b4ed9a826a5"}, - {file = "brax-0.10.4.tar.gz", hash = "sha256:6646bb5e280d3de2301f4908f236a14333817bdba5c7ec7faf38d4e8a627aec8"}, + {file = "brax-0.10.5-py3-none-any.whl", hash = "sha256:304fe6e5e266e42a18f197f2b7b6a9bb03a87bd97928e385c51c874b56f95866"}, + {file = "brax-0.10.5.tar.gz", hash = "sha256:e7563130c2b08bf0c9453d87602126732f20afc4624cb8574b3577fa62fdbcec"}, ] [[package]] @@ -2172,7 +2173,7 @@ files = [ [[package]] name = "pytest" -version = "8.2.1" +version = "8.2.2" requires_python = ">=3.8" summary = "pytest: simple powerful testing with Python" groups = ["default", "dev"] @@ -2183,8 +2184,8 @@ dependencies = [ "pluggy<2.0,>=1.5", ] files = [ - {file = "pytest-8.2.1-py3-none-any.whl", hash = "sha256:faccc5d332b8c3719f40283d0d44aa5cf101cec36f88cde9ed8f2bc0538612b1"}, - {file = "pytest-8.2.1.tar.gz", hash = "sha256:5046e5b46d8e4cac199c373041f26be56fdb81eb4e67dc11d4e10811fc3408fd"}, + {file = "pytest-8.2.2-py3-none-any.whl", hash = "sha256:c434598117762e2bd304e526244f67bf66bbd7b5d6cf22138be51ff661980343"}, + {file = "pytest-8.2.2.tar.gz", hash = "sha256:de4bb8104e201939ccdc688b27a89a7be2079b22e2bd2b07f806b6ba71117977"}, ] [[package]] @@ -2460,28 +2461,28 @@ files = [ [[package]] name = "ruff" -version = "0.4.6" +version = "0.4.9" requires_python = ">=3.7" summary = "An extremely fast Python linter and code formatter, written in Rust." groups = ["dev"] files = [ - {file = "ruff-0.4.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:ef995583a038cd4a7edf1422c9e19118e2511b8ba0b015861b4abd26ec5367c5"}, - {file = "ruff-0.4.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:602ebd7ad909eab6e7da65d3c091547781bb06f5f826974a53dbe563d357e53c"}, - {file = "ruff-0.4.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f9ced5cbb7510fd7525448eeb204e0a22cabb6e99a3cb160272262817d49786"}, - {file = "ruff-0.4.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:04a80acfc862e0e1630c8b738e70dcca03f350bad9e106968a8108379e12b31f"}, - {file = "ruff-0.4.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:be47700ecb004dfa3fd4dcdddf7322d4e632de3c06cd05329d69c45c0280e618"}, - {file = "ruff-0.4.6-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:1ff930d6e05f444090a0139e4e13e1e2e1f02bd51bb4547734823c760c621e79"}, - {file = "ruff-0.4.6-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f13410aabd3b5776f9c5699f42b37a3a348d65498c4310589bc6e5c548dc8a2f"}, - {file = "ruff-0.4.6-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0cf5cc02d3ae52dfb0c8a946eb7a1d6ffe4d91846ffc8ce388baa8f627e3bd50"}, - {file = "ruff-0.4.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ea3424793c29906407e3cf417f28fc33f689dacbbadfb52b7e9a809dd535dcef"}, - {file = "ruff-0.4.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:1fa8561489fadf483ffbb091ea94b9c39a00ed63efacd426aae2f197a45e67fc"}, - {file = "ruff-0.4.6-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:4d5b914818d8047270308fe3e85d9d7f4a31ec86c6475c9f418fbd1624d198e0"}, - {file = "ruff-0.4.6-py3-none-musllinux_1_2_i686.whl", hash = 
"sha256:4f02284335c766678778475e7698b7ab83abaf2f9ff0554a07b6f28df3b5c259"}, - {file = "ruff-0.4.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:3a6a0a4f4b5f54fff7c860010ab3dd81425445e37d35701a965c0248819dde7a"}, - {file = "ruff-0.4.6-py3-none-win32.whl", hash = "sha256:9018bf59b3aa8ad4fba2b1dc0299a6e4e60a4c3bc62bbeaea222679865453062"}, - {file = "ruff-0.4.6-py3-none-win_amd64.whl", hash = "sha256:a769ae07ac74ff1a019d6bd529426427c3e30d75bdf1e08bb3d46ac8f417326a"}, - {file = "ruff-0.4.6-py3-none-win_arm64.whl", hash = "sha256:735a16407a1a8f58e4c5b913ad6102722e80b562dd17acb88887685ff6f20cf6"}, - {file = "ruff-0.4.6.tar.gz", hash = "sha256:a797a87da50603f71e6d0765282098245aca6e3b94b7c17473115167d8dfb0b7"}, + {file = "ruff-0.4.9-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b262ed08d036ebe162123170b35703aaf9daffecb698cd367a8d585157732991"}, + {file = "ruff-0.4.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:98ec2775fd2d856dc405635e5ee4ff177920f2141b8e2d9eb5bd6efd50e80317"}, + {file = "ruff-0.4.9-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4555056049d46d8a381f746680db1c46e67ac3b00d714606304077682832998e"}, + {file = "ruff-0.4.9-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e91175fbe48f8a2174c9aad70438fe9cb0a5732c4159b2a10a3565fea2d94cde"}, + {file = "ruff-0.4.9-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0e8e7b95673f22e0efd3571fb5b0cf71a5eaaa3cc8a776584f3b2cc878e46bff"}, + {file = "ruff-0.4.9-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:2d45ddc6d82e1190ea737341326ecbc9a61447ba331b0a8962869fcada758505"}, + {file = "ruff-0.4.9-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:78de3fdb95c4af084087628132336772b1c5044f6e710739d440fc0bccf4d321"}, + {file = "ruff-0.4.9-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:06b60f91bfa5514bb689b500a25ba48e897d18fea14dce14b48a0c40d1635893"}, + {file = "ruff-0.4.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:88bffe9c6a454bf8529f9ab9091c99490578a593cc9f9822b7fc065ee0712a06"}, + {file = "ruff-0.4.9-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:673bddb893f21ab47a8334c8e0ea7fd6598ecc8e698da75bcd12a7b9d0a3206e"}, + {file = "ruff-0.4.9-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8c1aff58c31948cc66d0b22951aa19edb5af0a3af40c936340cd32a8b1ab7438"}, + {file = "ruff-0.4.9-py3-none-musllinux_1_2_i686.whl", hash = "sha256:784d3ec9bd6493c3b720a0b76f741e6c2d7d44f6b2be87f5eef1ae8cc1d54c84"}, + {file = "ruff-0.4.9-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:732dd550bfa5d85af8c3c6cbc47ba5b67c6aed8a89e2f011b908fc88f87649db"}, + {file = "ruff-0.4.9-py3-none-win32.whl", hash = "sha256:8064590fd1a50dcf4909c268b0e7c2498253273309ad3d97e4a752bb9df4f521"}, + {file = "ruff-0.4.9-py3-none-win_amd64.whl", hash = "sha256:e0a22c4157e53d006530c902107c7f550b9233e9706313ab57b892d7197d8e52"}, + {file = "ruff-0.4.9-py3-none-win_arm64.whl", hash = "sha256:5d5460f789ccf4efd43f265a58538a2c24dbce15dbf560676e430375f20a8198"}, + {file = "ruff-0.4.9.tar.gz", hash = "sha256:f1cb0828ac9533ba0135d148d214e284711ede33640465e706772645483427e3"}, ] [[package]] @@ -2731,18 +2732,18 @@ name = "tensor-regression" version = "0.0.2.post3.dev0" requires_python = "<4.0,>=3.11" git = "https://www.github.com/lebrice/tensor_regression" -revision = "2b15f9312fe8891f0c617b5cbce1ba757d514a0a" +revision = "7b3a07ae924eaeacde6ebeade2efcd7f8ce526d5" summary = "A small wrapper around pytest_regressions for Tensors" 
groups = ["default", "dev"] dependencies = [ "numpy<2.0.0,>=1.26.4", "pytest-regressions<3.0.0,>=2.5.0", - "torch<3.0.0,>=2.3.1", + "torch<3.0.0,>=2.0.0", ] [[package]] name = "tensorboard" -version = "2.16.2" +version = "2.17.0" requires_python = ">=3.9" summary = "TensorBoard lets you watch Tensors Flow" groups = ["default"] @@ -2751,14 +2752,14 @@ dependencies = [ "grpcio>=1.48.2", "markdown>=2.6.8", "numpy>=1.12.0", - "protobuf!=4.24.0,>=3.19.6", + "protobuf!=4.24.0,<5.0.0,>=3.19.6", "setuptools>=41.0.0", "six>1.9", "tensorboard-data-server<0.8.0,>=0.7.0", "werkzeug>=1.0.1", ] files = [ - {file = "tensorboard-2.16.2-py3-none-any.whl", hash = "sha256:9f2b4e7dad86667615c0e5cd072f1ea8403fc032a299f0072d6f74855775cc45"}, + {file = "tensorboard-2.17.0-py3-none-any.whl", hash = "sha256:859a499a9b1fb68a058858964486627100b71fcb21646861c61d31846a6478fb"}, ] [[package]] @@ -2853,16 +2854,16 @@ files = [ [[package]] name = "torch-jax-interop" version = "0.0.4.post7.dev0" -requires_python = ">=3.11,<4.0" +requires_python = "<4.0,>=3.11" git = "https://www.github.com/lebrice/torch_jax_interop" -revision = "7f0c72fe19d8bd4bd957f20dd90d77acd8178bd4" +revision = "3a4261f949d739cfe684280203137114b169e70e" summary = "Utility to convert Tensors from Jax to Torch and vice-versa" groups = ["default"] dependencies = [ - "flax<0.9.0,>=0.8.4", - "jax[cuda12]<0.5.0,>=0.4.28", - "pytorch2jax<0.2.0,>=0.1.0", - "torch<3.0.0,>=2.3.0", + "flax<1.0.0,>=0.8.4", + "jax[cuda12]<1.0.0,>=0.4.28", + "pytorch2jax<1.0.0,>=0.1.0", + "torch<3.0.0,>=2.0.0", ] [[package]] @@ -3000,7 +3001,7 @@ files = [ [[package]] name = "wandb" -version = "0.17.0" +version = "0.17.2" requires_python = ">=3.7" summary = "A CLI and library for interacting with the Weights & Biases API." groups = ["default"] @@ -3009,8 +3010,8 @@ dependencies = [ "docker-pycreds>=0.4.0", "gitpython!=3.1.29,>=1.0.0", "platformdirs", - "protobuf!=4.21.0,<5,>=3.19.0; python_version > \"3.9\" and sys_platform == \"linux\"", - "protobuf!=4.21.0,<5,>=3.19.0; sys_platform != \"linux\"", + "protobuf!=4.21.0,<6,>=3.19.0; python_version > \"3.9\" and sys_platform == \"linux\"", + "protobuf!=4.21.0,<6,>=3.19.0; sys_platform != \"linux\"", "psutil>=5.0.0", "pyyaml", "requests<3,>=2.0.0", @@ -3019,13 +3020,13 @@ dependencies = [ "setuptools", ] files = [ - {file = "wandb-0.17.0-py3-none-any.whl", hash = "sha256:b1b056b4cad83b00436cb76049fd29ecedc6045999dcaa5eba40db6680960ac2"}, - {file = "wandb-0.17.0-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:e1e6f04e093a6a027dcb100618ca23b122d032204b2ed4c62e4e991a48041a6b"}, - {file = "wandb-0.17.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:feeb60d4ff506d2a6bc67f953b310d70b004faa789479c03ccd1559c6f1a9633"}, - {file = "wandb-0.17.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7bed8a3dd404a639e6bf5fea38c6efe2fb98d416ff1db4fb51be741278ed328"}, - {file = "wandb-0.17.0-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56a1dd6e0e635cba3f6ed30b52c71739bdc2a3e57df155619d2d80ee952b4201"}, - {file = "wandb-0.17.0-py3-none-win32.whl", hash = "sha256:1f692d3063a0d50474022cfe6668e1828260436d1cd40827d1e136b7f730c74c"}, - {file = "wandb-0.17.0-py3-none-win_amd64.whl", hash = "sha256:ab582ca0d54d52ef5b991de0717350b835400d9ac2d3adab210022b68338d694"}, + {file = "wandb-0.17.2-py3-none-any.whl", hash = "sha256:4bd351be28cea87730365856cfaa72f72ceb787accc21bad359dde5aa9c4356d"}, + {file = "wandb-0.17.2-py3-none-macosx_10_9_x86_64.whl", hash = 
"sha256:638353a2d702caedd304a5f1e526ef93a291c984c109fcb444262a57aeaacec9"}, + {file = "wandb-0.17.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:824e33ca77af87f87a9cf1122acba164da5bf713adc9d67332bc686028921ec9"}, + {file = "wandb-0.17.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:032ca5939008643349af178a8b66b8047a1eefcb870c4c4a86e22acafde6470f"}, + {file = "wandb-0.17.2-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9558bab47a0c8ac4f22cfa2d43f91d1bc1f75d4255629286db674fe49fcd30e5"}, + {file = "wandb-0.17.2-py3-none-win32.whl", hash = "sha256:4bc176e3c81be216dc889fcd098341eb17a14b04e080d4343ce3f0b1740abfc1"}, + {file = "wandb-0.17.2-py3-none-win_amd64.whl", hash = "sha256:62cd707f38b5711971729dae80343b8c35f6003901e690166cc6d526187a9785"}, ] [[package]] From 662212dc37edd096d491a27a26748916619ce7f1 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Thu, 20 Jun 2024 17:06:52 +0000 Subject: [PATCH 24/27] Find working combination of pydantic/lightning Signed-off-by: Fabrice Normandin --- pdm.lock | 788 +++++++++++++------------------------------------ pyproject.toml | 4 +- 2 files changed, 216 insertions(+), 576 deletions(-) diff --git a/pdm.lock b/pdm.lock index 80be51c6..1eb350cd 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.4.1" -content_hash = "sha256:df6ae2c882d59bd85dcb8e915205e54db01eab36f4a4910af80acaa945817388" +content_hash = "sha256:2e4f9911bfebcfc3a32ec5f7c9257db49d5c51fde8087009c4b966873563c79c" [[package]] name = "absl-py" @@ -65,14 +65,14 @@ files = [ ] [[package]] -name = "ansicon" -version = "1.89.0" -summary = "Python wrapper for loading Jason Hood's ANSICON" +name = "annotated-types" +version = "0.7.0" +requires_python = ">=3.8" +summary = "Reusable constraint types to use with typing.Annotated" groups = ["default"] -marker = "platform_system == \"Windows\"" files = [ - {file = "ansicon-1.89.0-py2.py3-none-any.whl", hash = "sha256:f1def52d17f65c2c9682cf8370c03f541f410c1752d6a14029f97318e4b9dfec"}, - {file = "ansicon-1.89.0.tar.gz", hash = "sha256:e4d039def5768a47e4afec8e89e83ec3ae5a26bf00ad851f914d1240b444d2b1"}, + {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"}, + {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, ] [[package]] @@ -84,36 +84,6 @@ files = [ {file = "antlr4-python3-runtime-4.9.3.tar.gz", hash = "sha256:f224469b4168294902bb1efa80a8bf7855f24c99aef99cbefc1bcd3cce77881b"}, ] -[[package]] -name = "anyio" -version = "4.4.0" -requires_python = ">=3.8" -summary = "High level compatibility layer for multiple asynchronous event loop implementations" -groups = ["default"] -dependencies = [ - "idna>=2.8", - "sniffio>=1.1", -] -files = [ - {file = "anyio-4.4.0-py3-none-any.whl", hash = "sha256:c1b2d8f46a8a812513012e1107cb0e68c17159a7a594208005a57dc776e1bdc7"}, - {file = "anyio-4.4.0.tar.gz", hash = "sha256:5aadc6a1bbb7cdb0bede386cac5e2940f5e2ff3aa20277e991cf028e0585ce94"}, -] - -[[package]] -name = "arrow" -version = "1.3.0" -requires_python = ">=3.8" -summary = "Better dates & times for Python" -groups = ["default"] -dependencies = [ - "python-dateutil>=2.7.0", - "types-python-dateutil>=2.8.10", -] -files = [ - {file = "arrow-1.3.0-py3-none-any.whl", hash = 
"sha256:c728b120ebc00eb84e01882a6f5e7927a53960aa990ce7dd2b10f39005a67f80"}, - {file = "arrow-1.3.0.tar.gz", hash = "sha256:d4540617648cb5f895730f1ad8c82a65f2dad0166f57b75f3ca54759c4d67a85"}, -] - [[package]] name = "attrs" version = "23.2.0" @@ -139,22 +109,6 @@ files = [ {file = "beautifulsoup4-4.12.3.tar.gz", hash = "sha256:74e3d1928edc070d21748185c46e3fb33490f22f52a3addee9aee0f4f7781051"}, ] -[[package]] -name = "blessed" -version = "1.20.0" -requires_python = ">=2.7" -summary = "Easy, practical library for making terminal apps, by providing an elegant, well-documented interface to Colors, Keyboard input, and screen Positioning capabilities." -groups = ["default"] -dependencies = [ - "jinxed>=1.1.0; platform_system == \"Windows\"", - "six>=1.9.0", - "wcwidth>=0.1.4", -] -files = [ - {file = "blessed-1.20.0-py2.py3-none-any.whl", hash = "sha256:0c542922586a265e699188e52d5f5ac5ec0dd517e5a1041d90d2bbf23f906058"}, - {file = "blessed-1.20.0.tar.gz", hash = "sha256:2cdd67f8746e048f00df47a2880f4d6acbcdb399031b604e34ba8f71d5787680"}, -] - [[package]] name = "blinker" version = "1.8.2" @@ -166,38 +120,6 @@ files = [ {file = "blinker-1.8.2.tar.gz", hash = "sha256:8f77b09d3bf7c795e969e9486f39c2c5e9c39d4ee07424be2bc594ece9642d83"}, ] -[[package]] -name = "boto3" -version = "1.34.116" -requires_python = ">=3.8" -summary = "The AWS SDK for Python" -groups = ["default"] -dependencies = [ - "botocore<1.35.0,>=1.34.116", - "jmespath<2.0.0,>=0.7.1", - "s3transfer<0.11.0,>=0.10.0", -] -files = [ - {file = "boto3-1.34.116-py3-none-any.whl", hash = "sha256:e7f5ab2d1f1b90971a2b9369760c2c6bae49dae98c084a5c3f5c78e3968ace15"}, - {file = "boto3-1.34.116.tar.gz", hash = "sha256:53cb8aeb405afa1cd2b25421e27a951aeb568026675dec020587861fac96ac87"}, -] - -[[package]] -name = "botocore" -version = "1.34.116" -requires_python = ">=3.8" -summary = "Low-level, data-driven core of boto 3." -groups = ["default"] -dependencies = [ - "jmespath<2.0.0,>=0.7.1", - "python-dateutil<3.0.0,>=2.1", - "urllib3!=2.2.0,<3,>=1.25.4; python_version >= \"3.10\"", -] -files = [ - {file = "botocore-1.34.116-py3-none-any.whl", hash = "sha256:ec4d42c816e9b2d87a2439ad277e7dda16a4a614ef6839cf66f4c1a58afa547c"}, - {file = "botocore-1.34.116.tar.gz", hash = "sha256:269cae7ba99081519a9f87d7298e238d9e68ba94eb4f8ddfa906224c34cb8b6c"}, -] - [[package]] name = "brax" version = "0.10.5" @@ -236,13 +158,13 @@ files = [ [[package]] name = "certifi" -version = "2024.2.2" +version = "2024.6.2" requires_python = ">=3.6" summary = "Python package for providing Mozilla's CA Bundle." groups = ["default"] files = [ - {file = "certifi-2024.2.2-py3-none-any.whl", hash = "sha256:dc383c07b76109f368f6106eee2b593b04a011ea4d55f652c6ca24a754d1cdd1"}, - {file = "certifi-2024.2.2.tar.gz", hash = "sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f"}, + {file = "certifi-2024.6.2-py3-none-any.whl", hash = "sha256:ddc6c8ce995e6987e7faf5e3f1b02b302836a0e5d98ece18392cb1a36c72ad56"}, + {file = "certifi-2024.6.2.tar.gz", hash = "sha256:3cd43f1c6fa7dedc5899d69d3ad0398fd018ad1a17fba83ddaf78aa46c747516"}, ] [[package]] @@ -322,7 +244,7 @@ version = "0.4.6" requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" summary = "Cross-platform colored terminal text." 
groups = ["default", "dev"] -marker = "platform_system == \"Windows\" or sys_platform == \"win32\"" +marker = "sys_platform == \"win32\" or platform_system == \"Windows\"" files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, @@ -425,20 +347,6 @@ files = [ {file = "coverage-7.5.3.tar.gz", hash = "sha256:04aefca5190d1dc7a53a4c1a5a7f8568811306d7a8ee231c42fb69215571944f"}, ] -[[package]] -name = "croniter" -version = "1.3.15" -requires_python = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -summary = "croniter provides iteration for datetime object with cron like format" -groups = ["default"] -dependencies = [ - "python-dateutil", -] -files = [ - {file = "croniter-1.3.15-py2.py3-none-any.whl", hash = "sha256:f17f877be1d93b9e3191151584a19d8b367b017ab0febc8c5472b9300da61c4c"}, - {file = "croniter-1.3.15.tar.gz", hash = "sha256:924a38fda88f675ec6835667e1d32ac37ff0d65509c2152729d16ff205e32a65"}, -] - [[package]] name = "cycler" version = "0.12.1" @@ -450,20 +358,6 @@ files = [ {file = "cycler-0.12.1.tar.gz", hash = "sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c"}, ] -[[package]] -name = "dateutils" -version = "0.6.12" -summary = "Various utilities for working with date and datetime objects" -groups = ["default"] -dependencies = [ - "python-dateutil", - "pytz", -] -files = [ - {file = "dateutils-0.6.12-py2.py3-none-any.whl", hash = "sha256:f33b6ab430fa4166e7e9cb8b21ee9f6c9843c48df1a964466f52c79b2a8d53b3"}, - {file = "dateutils-0.6.12.tar.gz", hash = "sha256:03dd90bcb21541bd4eb4b013637e4f1b5f944881c46cc6e4b67a6059e370e3f1"}, -] - [[package]] name = "decorator" version = "4.4.2" @@ -475,20 +369,6 @@ files = [ {file = "decorator-4.4.2.tar.gz", hash = "sha256:e3a62f0520172440ca0dcc823749319382e377f37f140a0b99ef45fecb84bfe7"}, ] -[[package]] -name = "deepdiff" -version = "7.0.1" -requires_python = ">=3.8" -summary = "Deep Difference and Search of any Python object/data. Recreate objects by adding adding deltas to each other." 
-groups = ["default"] -dependencies = [ - "ordered-set<4.2.0,>=4.1.0", -] -files = [ - {file = "deepdiff-7.0.1-py3-none-any.whl", hash = "sha256:447760081918216aa4fd4ca78a4b6a848b81307b2ea94c810255334b759e1dc3"}, - {file = "deepdiff-7.0.1.tar.gz", hash = "sha256:260c16f052d4badbf60351b4f77e8390bee03a0b516246f6839bc813fb429ddf"}, -] - [[package]] name = "dm-env" version = "1.6" @@ -535,40 +415,36 @@ files = [ ] [[package]] -name = "editor" -version = "1.6.6" -requires_python = ">=3.8" -summary = "🖋 Open the default text editor 🖋" +name = "docstring-parser" +version = "0.16" +requires_python = ">=3.6,<4.0" +summary = "Parse Python docstrings in reST, Google and Numpydoc format" groups = ["default"] -dependencies = [ - "runs", - "xmod", -] files = [ - {file = "editor-1.6.6-py3-none-any.whl", hash = "sha256:e818e6913f26c2a81eadef503a2741d7cca7f235d20e217274a009ecd5a74abf"}, - {file = "editor-1.6.6.tar.gz", hash = "sha256:bb6989e872638cd119db9a4fce284cd8e13c553886a1c044c6b8d8a160c871f8"}, + {file = "docstring_parser-0.16-py3-none-any.whl", hash = "sha256:bf0a1387354d3691d102edef7ec124f219ef639982d096e26e3b60aeffa90637"}, + {file = "docstring_parser-0.16.tar.gz", hash = "sha256:538beabd0af1e2db0146b6bd3caa526c35a34d61af9fd2887f3a8a27a739aa6e"}, ] [[package]] name = "etils" -version = "1.9.0" +version = "1.9.2" requires_python = ">=3.11" summary = "Collection of common python utils" groups = ["default"] files = [ - {file = "etils-1.9.0-py3-none-any.whl", hash = "sha256:b4b9ea97a888f7c8e07de37d0547e303298f4bb7616143c5f027a99a82a6cd84"}, - {file = "etils-1.9.0.tar.gz", hash = "sha256:5d0f8ddaa8e0e640c685ed7a7fe1fc5c8162533fa12fb945f09ecc539b0b366c"}, + {file = "etils-1.9.2-py3-none-any.whl", hash = "sha256:ecd79de1fbfea9b0d6924756cfa922b05ed3360c45cf2170767da4bee0001d20"}, + {file = "etils-1.9.2.tar.gz", hash = "sha256:15dcd35ac0c0cc2404b46ac0846af3cc4e876fd3d80f36f57951e27e8b9d6379"}, ] [[package]] name = "etils" -version = "1.9.0" +version = "1.9.2" extras = ["epath", "epy"] requires_python = ">=3.11" summary = "Collection of common python utils" groups = ["default"] dependencies = [ - "etils==1.9.0", + "etils==1.9.2", "etils[epy]", "fsspec", "importlib-resources", @@ -577,19 +453,19 @@ dependencies = [ "zipp", ] files = [ - {file = "etils-1.9.0-py3-none-any.whl", hash = "sha256:b4b9ea97a888f7c8e07de37d0547e303298f4bb7616143c5f027a99a82a6cd84"}, - {file = "etils-1.9.0.tar.gz", hash = "sha256:5d0f8ddaa8e0e640c685ed7a7fe1fc5c8162533fa12fb945f09ecc539b0b366c"}, + {file = "etils-1.9.2-py3-none-any.whl", hash = "sha256:ecd79de1fbfea9b0d6924756cfa922b05ed3360c45cf2170767da4bee0001d20"}, + {file = "etils-1.9.2.tar.gz", hash = "sha256:15dcd35ac0c0cc2404b46ac0846af3cc4e876fd3d80f36f57951e27e8b9d6379"}, ] [[package]] name = "etils" -version = "1.9.0" +version = "1.9.2" extras = ["epath"] requires_python = ">=3.11" summary = "Collection of common python utils" groups = ["default"] dependencies = [ - "etils==1.9.0", + "etils==1.9.2", "etils[epy]", "fsspec", "importlib-resources", @@ -597,24 +473,24 @@ dependencies = [ "zipp", ] files = [ - {file = "etils-1.9.0-py3-none-any.whl", hash = "sha256:b4b9ea97a888f7c8e07de37d0547e303298f4bb7616143c5f027a99a82a6cd84"}, - {file = "etils-1.9.0.tar.gz", hash = "sha256:5d0f8ddaa8e0e640c685ed7a7fe1fc5c8162533fa12fb945f09ecc539b0b366c"}, + {file = "etils-1.9.2-py3-none-any.whl", hash = "sha256:ecd79de1fbfea9b0d6924756cfa922b05ed3360c45cf2170767da4bee0001d20"}, + {file = "etils-1.9.2.tar.gz", hash = "sha256:15dcd35ac0c0cc2404b46ac0846af3cc4e876fd3d80f36f57951e27e8b9d6379"}, 
] [[package]] name = "etils" -version = "1.9.0" +version = "1.9.2" extras = ["epy"] requires_python = ">=3.11" summary = "Collection of common python utils" groups = ["default"] dependencies = [ - "etils==1.9.0", + "etils==1.9.2", "typing-extensions", ] files = [ - {file = "etils-1.9.0-py3-none-any.whl", hash = "sha256:b4b9ea97a888f7c8e07de37d0547e303298f4bb7616143c5f027a99a82a6cd84"}, - {file = "etils-1.9.0.tar.gz", hash = "sha256:5d0f8ddaa8e0e640c685ed7a7fe1fc5c8162533fa12fb945f09ecc539b0b366c"}, + {file = "etils-1.9.2-py3-none-any.whl", hash = "sha256:ecd79de1fbfea9b0d6924756cfa922b05ed3360c45cf2170767da4bee0001d20"}, + {file = "etils-1.9.2.tar.gz", hash = "sha256:15dcd35ac0c0cc2404b46ac0846af3cc4e876fd3d80f36f57951e27e8b9d6379"}, ] [[package]] @@ -638,30 +514,15 @@ files = [ {file = "Farama_Notifications-0.0.4-py3-none-any.whl", hash = "sha256:14de931035a41961f7c056361dc7f980762a143d05791ef5794a751a2caf05ae"}, ] -[[package]] -name = "fastapi" -version = "0.88.0" -requires_python = ">=3.7" -summary = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" -groups = ["default"] -dependencies = [ - "pydantic!=1.7,!=1.7.1,!=1.7.2,!=1.7.3,!=1.8,!=1.8.1,<2.0.0,>=1.6.2", - "starlette==0.22.0", -] -files = [ - {file = "fastapi-0.88.0-py3-none-any.whl", hash = "sha256:263b718bb384422fe3d042ffc9a0c8dece5e034ab6586ff034f6b4b1667c3eee"}, - {file = "fastapi-0.88.0.tar.gz", hash = "sha256:915bf304180a0e7c5605ec81097b7d4cd8826ff87a02bb198e336fb9f3b5ff02"}, -] - [[package]] name = "filelock" -version = "3.14.0" +version = "3.15.3" requires_python = ">=3.8" summary = "A platform independent file lock." groups = ["default", "dev"] files = [ - {file = "filelock-3.14.0-py3-none-any.whl", hash = "sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f"}, - {file = "filelock-3.14.0.tar.gz", hash = "sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a"}, + {file = "filelock-3.15.3-py3-none-any.whl", hash = "sha256:0151273e5b5d6cf753a61ec83b3a9b7d8821c39ae9af9d7ecf2f9e2f17404103"}, + {file = "filelock-3.15.3.tar.gz", hash = "sha256:e1199bf5194a2277273dacd50269f0d87d0682088a3c561c15674ea9005d8635"}, ] [[package]] @@ -856,21 +717,21 @@ files = [ [[package]] name = "grpcio" -version = "1.64.0" +version = "1.64.1" requires_python = ">=3.8" summary = "HTTP/2-based RPC framework" groups = ["default"] files = [ - {file = "grpcio-1.64.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:1ce4cd5a61d4532651079e7aae0fedf9a80e613eed895d5b9743e66b52d15812"}, - {file = "grpcio-1.64.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:650a8150a9b288f40d5b7c1d5400cc11724eae50bd1f501a66e1ea949173649b"}, - {file = "grpcio-1.64.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:8de0399b983f8676a7ccfdd45e5b2caec74a7e3cc576c6b1eecf3b3680deda5e"}, - {file = "grpcio-1.64.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:46b8b43ba6a2a8f3103f103f97996cad507bcfd72359af6516363c48793d5a7b"}, - {file = "grpcio-1.64.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a54362f03d4dcfae63be455d0a7d4c1403673498b92c6bfe22157d935b57c7a9"}, - {file = "grpcio-1.64.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:1f8ea18b928e539046bb5f9c124d717fbf00cc4b2d960ae0b8468562846f5aa1"}, - {file = "grpcio-1.64.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c56c91bd2923ddb6e7ed28ebb66d15633b03e0df22206f22dfcdde08047e0a48"}, - {file = "grpcio-1.64.0-cp312-cp312-win32.whl", hash = 
"sha256:874c741c8a66f0834f653a69e7e64b4e67fcd4a8d40296919b93bab2ccc780ba"}, - {file = "grpcio-1.64.0-cp312-cp312-win_amd64.whl", hash = "sha256:0da1d921f8e4bcee307aeef6c7095eb26e617c471f8cb1c454fd389c5c296d1e"}, - {file = "grpcio-1.64.0.tar.gz", hash = "sha256:257baf07f53a571c215eebe9679c3058a313fd1d1f7c4eede5a8660108c52d9c"}, + {file = "grpcio-1.64.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:4657d24c8063e6095f850b68f2d1ba3b39f2b287a38242dcabc166453e950c59"}, + {file = "grpcio-1.64.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:62b4e6eb7bf901719fce0ca83e3ed474ae5022bb3827b0a501e056458c51c0a1"}, + {file = "grpcio-1.64.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:ee73a2f5ca4ba44fa33b4d7d2c71e2c8a9e9f78d53f6507ad68e7d2ad5f64a22"}, + {file = "grpcio-1.64.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:198908f9b22e2672a998870355e226a725aeab327ac4e6ff3a1399792ece4762"}, + {file = "grpcio-1.64.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b9d0acaa8d835a6566c640f48b50054f422d03e77e49716d4c4e8e279665a1"}, + {file = "grpcio-1.64.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:5e42634a989c3aa6049f132266faf6b949ec2a6f7d302dbb5c15395b77d757eb"}, + {file = "grpcio-1.64.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:b1a82e0b9b3022799c336e1fc0f6210adc019ae84efb7321d668129d28ee1efb"}, + {file = "grpcio-1.64.1-cp312-cp312-win32.whl", hash = "sha256:55260032b95c49bee69a423c2f5365baa9369d2f7d233e933564d8a47b893027"}, + {file = "grpcio-1.64.1-cp312-cp312-win_amd64.whl", hash = "sha256:c1a786ac592b47573a5bb7e35665c08064a5d77ab88a076eec11f8ae86b3e3f6"}, + {file = "grpcio-1.64.1.tar.gz", hash = "sha256:8d51dd1c59d5fa0f34266b80a3805ec29a1f26425c2a54736133f6d87fc4968a"}, ] [[package]] @@ -937,17 +798,6 @@ files = [ {file = "gymnax-0.0.8.tar.gz", hash = "sha256:81defc17f52a30a84338b3daa574d7a3bb112f2656f45c783a71efe31eea68ff"}, ] -[[package]] -name = "h11" -version = "0.14.0" -requires_python = ">=3.7" -summary = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" -groups = ["default"] -files = [ - {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, - {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, -] - [[package]] name = "hydra-colorlog" version = "1.2.0" @@ -1036,7 +886,7 @@ files = [ [[package]] name = "imageio-ffmpeg" -version = "0.5.0" +version = "0.5.1" requires_python = ">=3.5" summary = "FFMPEG wrapper for Python" groups = ["default"] @@ -1045,12 +895,12 @@ dependencies = [ "setuptools", ] files = [ - {file = "imageio-ffmpeg-0.5.0.tar.gz", hash = "sha256:75c9c45079510cfeb4849a17fcd3edd4f14062ea6b69c5b62695fb2075295c87"}, - {file = "imageio_ffmpeg-0.5.0-py3-none-macosx_10_9_intel.macosx_10_9_x86_64.macosx_10_10_intel.macosx_10_10_x86_64.whl", hash = "sha256:e9aba9cdd01164a50a4cfb1b825fc8769151a0d3b5b5a7d5d50ff9fcda7eee9c"}, - {file = "imageio_ffmpeg-0.5.0-py3-none-manylinux2010_x86_64.whl", hash = "sha256:ba55f392ee5db9eb0a6d7699e0060a2edcaa7dbc740ca29671bdc8dbb763ca3b"}, - {file = "imageio_ffmpeg-0.5.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9c813be7d6a24236bb68aeab249ea67f5a7fdf7d86988855578247694c42e94a"}, - {file = "imageio_ffmpeg-0.5.0-py3-none-win32.whl", hash = "sha256:c4a3b32fc38d4a26c15582bf12246ddae060932889da5c9da487cc675740039b"}, - {file = "imageio_ffmpeg-0.5.0-py3-none-win_amd64.whl", hash = 
"sha256:8135f4d146094b62b31721ca53fe943f4134e3578e22015468e3df595217c24b"}, + {file = "imageio-ffmpeg-0.5.1.tar.gz", hash = "sha256:0ed7a9b31f560b0c9d929c5291cd430edeb9bed3ce9a497480e536dd4326484c"}, + {file = "imageio_ffmpeg-0.5.1-py3-none-macosx_10_9_intel.macosx_10_9_x86_64.macosx_10_10_intel.macosx_10_10_x86_64.whl", hash = "sha256:1460e84712b9d06910c1f7bb524096b0341d4b7844cea6c20e099d0a24e795b1"}, + {file = "imageio_ffmpeg-0.5.1-py3-none-manylinux2010_x86_64.whl", hash = "sha256:5289f75c7f755b499653f3209fea4efd1430cba0e39831c381aad2d458f7a316"}, + {file = "imageio_ffmpeg-0.5.1-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7fa9132a291d5eb28c44553550deb40cbdab831f2a614e55360301a6582eb205"}, + {file = "imageio_ffmpeg-0.5.1-py3-none-win32.whl", hash = "sha256:89efe2c79979d8174ba8476deb7f74d74c331caee3fb2b65ba2883bec0737625"}, + {file = "imageio_ffmpeg-0.5.1-py3-none-win_amd64.whl", hash = "sha256:1521e79e253bedbdd36a547e0cbd94a025ba0b558e17f08fea687d805a0e4698"}, ] [[package]] @@ -1075,22 +925,6 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] -[[package]] -name = "inquirer" -version = "3.2.4" -requires_python = ">=3.8.1" -summary = "Collection of common interactive command line user interfaces, based on Inquirer.js" -groups = ["default"] -dependencies = [ - "blessed>=1.19.0", - "editor>=1.6.0", - "readchar>=3.0.6", -] -files = [ - {file = "inquirer-3.2.4-py3-none-any.whl", hash = "sha256:273a4e4a4345ac1afdb17408d40fc8dccf3485db68203357919468561035a763"}, - {file = "inquirer-3.2.4.tar.gz", hash = "sha256:33b09efc1b742b9d687b540296a8b6a3f773399673321fcc2ab0eb4c109bf9b5"}, -] - [[package]] name = "intel-openmp" version = "2021.4.0" @@ -1189,7 +1023,7 @@ files = [ [[package]] name = "jaxlib" -version = "0.4.28" +version = "0.4.28+cuda12.cudnn89" requires_python = ">=3.9" summary = "XLA library for JAX" groups = ["default"] @@ -1239,31 +1073,6 @@ files = [ {file = "jinja2-3.1.4.tar.gz", hash = "sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369"}, ] -[[package]] -name = "jinxed" -version = "1.2.1" -summary = "Jinxed Terminal Library" -groups = ["default"] -marker = "platform_system == \"Windows\"" -dependencies = [ - "ansicon; platform_system == \"Windows\"", -] -files = [ - {file = "jinxed-1.2.1-py2.py3-none-any.whl", hash = "sha256:37422659c4925969c66148c5e64979f553386a4226b9484d910d3094ced37d30"}, - {file = "jinxed-1.2.1.tar.gz", hash = "sha256:30c3f861b73279fea1ed928cfd4dfb1f273e16cd62c8a32acfac362da0f78f3f"}, -] - -[[package]] -name = "jmespath" -version = "1.0.1" -requires_python = ">=3.7" -summary = "JSON Matching Expressions" -groups = ["default"] -files = [ - {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"}, - {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, -] - [[package]] name = "kiwisolver" version = "1.4.5" @@ -1306,73 +1115,25 @@ files = [ [[package]] name = "lightning" -version = "1.9.0" -requires_python = ">=3.7" -summary = "Use Lightning Apps to build everything from production-ready, multi-cloud ML systems to simple research demos." +version = "2.3.0" +requires_python = ">=3.8" +summary = "The Deep Learning framework to train, deploy, and ship AI products Lightning fast." 
groups = ["default"] dependencies = [ - "Jinja2<5.0", - "PyYAML<8.0", "PyYAML<8.0,>=5.4", - "arrow<3.0,>=1.2.0", - "beautifulsoup4<6.0,>=4.8.0", - "click<10.0", - "croniter<1.4.0,>=1.3.0", - "dateutils<2.0", - "deepdiff<8.0,>=5.7.0", - "fastapi<0.89.0", - "fsspec<2024.0,>=2022.5.0", - "fsspec[http]<2024.0,>2021.06.0", - "inquirer<5.0,>=2.10.0", - "lightning-cloud<2.0,>=0.5.12", - "lightning-utilities<2.0,>=0.4.2", + "fsspec[http]<2026.0,>=2022.5.0", + "lightning-utilities<2.0,>=0.8.0", "numpy<3.0,>=1.17.2", - "packaging", - "packaging<23.0,>=17.1", - "psutil<7.0", - "pydantic<3.0", - "requests<4.0", - "rich<15.0", - "starlette<2.0", - "starsessions<2.0,>=1.2.1", - "torch<3.0,>=1.10.0", - "torchmetrics<2.0,>=0.7.0", + "packaging<25.0,>=20.0", + "pytorch-lightning", + "torch<4.0,>=2.0.0", + "torchmetrics<3.0,>=0.7.0", "tqdm<6.0,>=4.57.0", - "traitlets<7.0,>=5.3.0", - "typing-extensions<6.0,>=4.0.0", - "urllib3<3.0", - "uvicorn<2.0", - "websocket-client<3.0", - "websockets<12.0", -] -files = [ - {file = "lightning-1.9.0-py3-none-any.whl", hash = "sha256:27db661b37c3581fb3467016cbfcaf39aee77e52a12c32344b0c8da1d1e9e311"}, - {file = "lightning-1.9.0.tar.gz", hash = "sha256:d002270e2cd6bdf239d6605f8ec7f6f79bd2ec4eb5e7758b38ca36c57d4d1fdf"}, -] - -[[package]] -name = "lightning-cloud" -version = "0.5.69" -requires_python = ">=3.7.0" -summary = "Lightning Cloud" -groups = ["default"] -dependencies = [ - "boto3", - "click", - "fastapi", - "protobuf", - "pyjwt", - "python-multipart", - "requests", - "rich", - "six", - "urllib3", - "uvicorn", - "websocket-client", + "typing-extensions<6.0,>=4.4.0", ] files = [ - {file = "lightning_cloud-0.5.69-py3-none-any.whl", hash = "sha256:8e26b534c3970ea939d37c284e9de5d0c880339a49d18c9b9181c0e093f95fd1"}, - {file = "lightning_cloud-0.5.69.tar.gz", hash = "sha256:0baeef05c06a6d89c482abea1826cc3e3bec48901d10cc2749f39b344e6f1dc3"}, + {file = "lightning-2.3.0-py3-none-any.whl", hash = "sha256:ed66c2053be1295c8452b996b719badf5a26a0652607c121103dfdd5d2dccfae"}, + {file = "lightning-2.3.0.tar.gz", hash = "sha256:4bb4d6e3650d2d5f544ad60853a22efc4e164aa71b9596d13f0454b29df05130"}, ] [[package]] @@ -1584,7 +1345,7 @@ files = [ [[package]] name = "mujoco" -version = "3.1.5" +version = "3.1.6" requires_python = ">=3.8" summary = "MuJoCo Physics Simulator" groups = ["default"] @@ -1596,17 +1357,17 @@ dependencies = [ "pyopengl", ] files = [ - {file = "mujoco-3.1.5-cp312-cp312-macosx_10_16_x86_64.whl", hash = "sha256:0a78079b07e63d04f2985684ccd3a9937badba4cf51432662ff818b092442dbc"}, - {file = "mujoco-3.1.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4145c6277a1e71000a54c0bfef337c885a57452c5f0aa7cddf4b41932b639f41"}, - {file = "mujoco-3.1.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:20bb70bfee28e026efc71f6872871c689fa2eaecc54d019ae1a21362453619cd"}, - {file = "mujoco-3.1.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f93bf770c3c963efe03c27b34ca59015e27ae70cdd4272a8312e583f52dbf40"}, - {file = "mujoco-3.1.5-cp312-cp312-win_amd64.whl", hash = "sha256:8b139b1950ad52924e8666561414dd8f4f3f69f89364f1d0304371839be9264e"}, - {file = "mujoco-3.1.5.tar.gz", hash = "sha256:9099ba6001341cc9e38b7b94b8ef7a67346c7638fa3e94f520743a357891f296"}, + {file = "mujoco-3.1.6-cp312-cp312-macosx_10_16_x86_64.whl", hash = "sha256:dc0ab85bcda35b2d87df91b7a13152e970a7108d87ef811f28dc32b2dbfb6754"}, + {file = "mujoco-3.1.6-cp312-cp312-macosx_11_0_arm64.whl", hash = 
"sha256:37a41c5558bd8823da8b2822d2dd941a4c57ee11bf56be5e77ee157c0e5552a1"}, + {file = "mujoco-3.1.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:92e839b3a3758a0010673ec954a1728ce076be923f868d37739040b029489544"}, + {file = "mujoco-3.1.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07d3b8c270ba9ae5c87e8e37061277ccc0d46767959b68f2a5c5c1e065213021"}, + {file = "mujoco-3.1.6-cp312-cp312-win_amd64.whl", hash = "sha256:49a6b3f88446686aebd345b12d1ec38259701215c7db355725499be9c0e53ef0"}, + {file = "mujoco-3.1.6.tar.gz", hash = "sha256:7cf8887526f071e7411dc02ce1cd665e39b4b6083fdff49fe1348a82d2314651"}, ] [[package]] name = "mujoco-mjx" -version = "3.1.5" +version = "3.1.6" requires_python = ">=3.8" summary = "MuJoCo XLA (MJX)" groups = ["default"] @@ -1615,13 +1376,13 @@ dependencies = [ "etils[epath]", "jax", "jaxlib", - "mujoco>=3.1.5.dev0", + "mujoco>=3.1.6.dev0", "scipy", "trimesh", ] files = [ - {file = "mujoco_mjx-3.1.5-py3-none-any.whl", hash = "sha256:4fc54e10c0cb811fd97584222a00ce9fa433f79d7ce46a8d7b22c8a054c35238"}, - {file = "mujoco_mjx-3.1.5.tar.gz", hash = "sha256:ee6b409d694a0a34ab93803089e3c1297ed91ae6a9461661cd1d80a9f0565880"}, + {file = "mujoco_mjx-3.1.6-py3-none-any.whl", hash = "sha256:0392975c610a8cbd8ad71ba7d7f524fccdb28bacf041998fea34370dd83d46a3"}, + {file = "mujoco_mjx-3.1.6.tar.gz", hash = "sha256:22f70227c3b7ee94b9e89a706c7a9387ba6f34219ecce5a2dadffac225c6637a"}, ] [[package]] @@ -1722,6 +1483,7 @@ requires_python = ">=3" summary = "CUDA nvcc" groups = ["default"] files = [ + {file = "nvidia_cuda_nvcc_cu12-12.5.40-py3-none-manylinux2014_aarch64.whl", hash = "sha256:dcea4f7fa223ac32ad40503499cec117e5543a1b34bb91a886049821bfa75304"}, {file = "nvidia_cuda_nvcc_cu12-12.5.40-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8347e2458c99eb9db3c392035c1781798f2593d495554106cf45502eeabc1a10"}, {file = "nvidia_cuda_nvcc_cu12-12.5.40-py3-none-win_amd64.whl", hash = "sha256:616cd3280a05657d1e40d4985058bbd4c88384b92c88a7c30228643abe7465f2"}, ] @@ -1833,6 +1595,7 @@ requires_python = ">=3" summary = "Nvidia JIT LTO Library" groups = ["default", "dev"] files = [ + {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_aarch64.whl", hash = "sha256:004186d5ea6a57758fd6d57052a123c73a4815adf365eb8dd6a85c9eaa7535ff"}, {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_x86_64.whl", hash = "sha256:d9714f27c1d0f0895cd8915c07a87a1d0029a0aa36acaf9156952ec2a8a12189"}, {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-win_amd64.whl", hash = "sha256:c3401dc8543b52d3a8158007a0c1ab4e9c768fcbd24153a48c86972102197ddd"}, ] @@ -1898,14 +1661,14 @@ files = [ [[package]] name = "orbax-checkpoint" -version = "0.5.15" +version = "0.5.17" requires_python = ">=3.9" summary = "Orbax Checkpoint" groups = ["default"] dependencies = [ "absl-py", "etils[epath,epy]", - "jax>=0.4.9", + "jax>=0.4.25", "jaxlib", "msgpack", "nest-asyncio", @@ -1916,30 +1679,19 @@ dependencies = [ "typing-extensions", ] files = [ - {file = "orbax_checkpoint-0.5.15-py3-none-any.whl", hash = "sha256:658dd89bc925cecc584d89eaa19af9a7e16e3371377907eb713fbd59b85262e4"}, - {file = "orbax_checkpoint-0.5.15.tar.gz", hash = "sha256:15195e8d1b381b56f23a62a25599a3644f5d08655fa64f60bb1b938b8ffe7ef3"}, -] - -[[package]] -name = "ordered-set" -version = "4.1.0" -requires_python = ">=3.7" -summary = "An OrderedSet is a custom MutableSet that remembers its order, so that every" -groups = ["default"] -files = [ - {file = "ordered-set-4.1.0.tar.gz", hash = 
"sha256:694a8e44c87657c59292ede72891eb91d34131f6531463aab3009191c77364a8"}, - {file = "ordered_set-4.1.0-py3-none-any.whl", hash = "sha256:046e1132c71fcf3330438a539928932caf51ddbc582496833e23de611de14562"}, + {file = "orbax_checkpoint-0.5.17-py3-none-any.whl", hash = "sha256:212a29bd43c368ba4b62b0c12d565b56bf50ce6df904aedb6ab379a8a3206fd9"}, + {file = "orbax_checkpoint-0.5.17.tar.gz", hash = "sha256:705574a0b41d935b17312fe36988e72da1f33d57d97732937a77a84c02793f94"}, ] [[package]] name = "packaging" -version = "22.0" -requires_python = ">=3.7" +version = "24.1" +requires_python = ">=3.8" summary = "Core utilities for Python packages" groups = ["default", "dev"] files = [ - {file = "packaging-22.0-py3-none-any.whl", hash = "sha256:957e2148ba0e1a3b282772e791ef1d8083648bc131c8ab0c1feba110ce1146c3"}, - {file = "packaging-22.0.tar.gz", hash = "sha256:2198ec20bd4c017b8f9717e00f0c8714076fc2fd93816750ab48e2c41de2cfd3"}, + {file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"}, + {file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"}, ] [[package]] @@ -2053,18 +1805,19 @@ files = [ [[package]] name = "psutil" -version = "5.9.8" -requires_python = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*" +version = "6.0.0" +requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" summary = "Cross-platform lib for process and system monitoring in Python." groups = ["default"] files = [ - {file = "psutil-5.9.8-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:aee678c8720623dc456fa20659af736241f575d79429a0e5e9cf88ae0605cc81"}, - {file = "psutil-5.9.8-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8cb6403ce6d8e047495a701dc7c5bd788add903f8986d523e3e20b98b733e421"}, - {file = "psutil-5.9.8-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4"}, - {file = "psutil-5.9.8-cp37-abi3-win32.whl", hash = "sha256:bc56c2a1b0d15aa3eaa5a60c9f3f8e3e565303b465dbf57a1b730e7a2b9844e0"}, - {file = "psutil-5.9.8-cp37-abi3-win_amd64.whl", hash = "sha256:8db4c1b57507eef143a15a6884ca10f7c73876cdf5d51e713151c1236a0e68cf"}, - {file = "psutil-5.9.8-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8"}, - {file = "psutil-5.9.8.tar.gz", hash = "sha256:6be126e3225486dff286a8fb9a06246a5253f4c7c53b475ea5f5ac934e64194c"}, + {file = "psutil-6.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:c588a7e9b1173b6e866756dde596fd4cad94f9399daf99ad8c3258b3cb2b47a0"}, + {file = "psutil-6.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ed2440ada7ef7d0d608f20ad89a04ec47d2d3ab7190896cd62ca5fc4fe08bf0"}, + {file = "psutil-6.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd9a97c8e94059b0ef54a7d4baf13b405011176c3b6ff257c247cae0d560ecd"}, + {file = "psutil-6.0.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2e8d0054fc88153ca0544f5c4d554d42e33df2e009c4ff42284ac9ebdef4132"}, + {file = "psutil-6.0.0-cp37-abi3-win32.whl", hash = "sha256:a495580d6bae27291324fe60cea0b5a7c23fa36a7cd35035a16d93bdcf076b9d"}, + {file = "psutil-6.0.0-cp37-abi3-win_amd64.whl", hash = 
"sha256:33ea5e1c975250a720b3a6609c490db40dae5d83a4eb315170c4fe0d8b1f34b3"}, + {file = "psutil-6.0.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:ffe7fc9b6b36beadc8c322f84e1caff51e8703b88eee1da46d1e3a6ae11b4fd0"}, + {file = "psutil-6.0.0.tar.gz", hash = "sha256:8faae4f310b6d969fa26ca0545338b21f73c6b15db7c4a8d934a5482faa818f2"}, ] [[package]] @@ -2079,16 +1832,60 @@ files = [ [[package]] name = "pydantic" -version = "1.10.15" -requires_python = ">=3.7" -summary = "Data validation and settings management using python type hints" +version = "2.7.4" +requires_python = ">=3.8" +summary = "Data validation using Python type hints" groups = ["default"] dependencies = [ - "typing-extensions>=4.2.0", + "annotated-types>=0.4.0", + "pydantic-core==2.18.4", + "typing-extensions>=4.6.1", ] files = [ - {file = "pydantic-1.10.15-py3-none-any.whl", hash = "sha256:28e552a060ba2740d0d2aabe35162652c1459a0b9069fe0db7f4ee0e18e74d58"}, - {file = "pydantic-1.10.15.tar.gz", hash = "sha256:ca832e124eda231a60a041da4f013e3ff24949d94a01154b137fc2f2a43c3ffb"}, + {file = "pydantic-2.7.4-py3-none-any.whl", hash = "sha256:ee8538d41ccb9c0a9ad3e0e5f07bf15ed8015b481ced539a1759d8cc89ae90d0"}, + {file = "pydantic-2.7.4.tar.gz", hash = "sha256:0c84efd9548d545f63ac0060c1e4d39bb9b14db8b3c0652338aecc07b5adec52"}, +] + +[[package]] +name = "pydantic-core" +version = "2.18.4" +requires_python = ">=3.8" +summary = "Core functionality for Pydantic validation and serialization" +groups = ["default"] +dependencies = [ + "typing-extensions!=4.7.0,>=4.6.0", +] +files = [ + {file = "pydantic_core-2.18.4-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:6f5c4d41b2771c730ea1c34e458e781b18cc668d194958e0112455fff4e402b2"}, + {file = "pydantic_core-2.18.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2fdf2156aa3d017fddf8aea5adfba9f777db1d6022d392b682d2a8329e087cef"}, + {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4748321b5078216070b151d5271ef3e7cc905ab170bbfd27d5c83ee3ec436695"}, + {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:847a35c4d58721c5dc3dba599878ebbdfd96784f3fb8bb2c356e123bdcd73f34"}, + {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3c40d4eaad41f78e3bbda31b89edc46a3f3dc6e171bf0ecf097ff7a0ffff7cb1"}, + {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:21a5e440dbe315ab9825fcd459b8814bb92b27c974cbc23c3e8baa2b76890077"}, + {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01dd777215e2aa86dfd664daed5957704b769e726626393438f9c87690ce78c3"}, + {file = "pydantic_core-2.18.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4b06beb3b3f1479d32befd1f3079cc47b34fa2da62457cdf6c963393340b56e9"}, + {file = "pydantic_core-2.18.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:564d7922e4b13a16b98772441879fcdcbe82ff50daa622d681dd682175ea918c"}, + {file = "pydantic_core-2.18.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:0eb2a4f660fcd8e2b1c90ad566db2b98d7f3f4717c64fe0a83e0adb39766d5b8"}, + {file = "pydantic_core-2.18.4-cp312-none-win32.whl", hash = "sha256:8b8bab4c97248095ae0c4455b5a1cd1cdd96e4e4769306ab19dda135ea4cdb07"}, + {file = "pydantic_core-2.18.4-cp312-none-win_amd64.whl", hash = "sha256:14601cdb733d741b8958224030e2bfe21a4a881fb3dd6fbb21f071cabd48fa0a"}, + {file = "pydantic_core-2.18.4-cp312-none-win_arm64.whl", 
hash = "sha256:c1322d7dd74713dcc157a2b7898a564ab091ca6c58302d5c7b4c07296e3fd00f"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:574d92eac874f7f4db0ca653514d823a0d22e2354359d0759e3f6a406db5d55d"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1f4d26ceb5eb9eed4af91bebeae4b06c3fb28966ca3a8fb765208cf6b51102ab"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77450e6d20016ec41f43ca4a6c63e9fdde03f0ae3fe90e7c27bdbeaece8b1ed4"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d323a01da91851a4f17bf592faf46149c9169d68430b3146dcba2bb5e5719abc"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:43d447dd2ae072a0065389092a231283f62d960030ecd27565672bd40746c507"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:578e24f761f3b425834f297b9935e1ce2e30f51400964ce4801002435a1b41ef"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:81b5efb2f126454586d0f40c4d834010979cb80785173d1586df845a632e4e6d"}, + {file = "pydantic_core-2.18.4-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:ab86ce7c8f9bea87b9d12c7f0af71102acbf5ecbc66c17796cff45dae54ef9a5"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:90afc12421df2b1b4dcc975f814e21bc1754640d502a2fbcc6d41e77af5ec312"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:51991a89639a912c17bef4b45c87bd83593aee0437d8102556af4885811d59f5"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:293afe532740370aba8c060882f7d26cfd00c94cae32fd2e212a3a6e3b7bc15e"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b48ece5bde2e768197a2d0f6e925f9d7e3e826f0ad2271120f8144a9db18d5c8"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:eae237477a873ab46e8dd748e515c72c0c804fb380fbe6c85533c7de51f23a8f"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:834b5230b5dfc0c1ec37b2fda433b271cbbc0e507560b5d1588e2cc1148cf1ce"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:e858ac0a25074ba4bce653f9b5d0a85b7456eaddadc0ce82d3878c22489fa4ee"}, + {file = "pydantic_core-2.18.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:2fd41f6eff4c20778d717af1cc50eca52f5afe7805ee530a4fbd0bae284f16e9"}, + {file = "pydantic_core-2.18.4.tar.gz", hash = "sha256:ec3beeada09ff865c344ff3bc2f427f5e6c26401cc6113d77e372c3fdac73864"}, ] [[package]] @@ -2128,17 +1925,6 @@ files = [ {file = "pygments-2.18.0.tar.gz", hash = "sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199"}, ] -[[package]] -name = "pyjwt" -version = "2.8.0" -requires_python = ">=3.7" -summary = "JSON Web Token implementation in Python" -groups = ["default"] -files = [ - {file = "PyJWT-2.8.0-py3-none-any.whl", hash = "sha256:59127c392cc44c2da5bb3192169a91f429924e17aff6534d70fdc02ab3e04320"}, - {file = "PyJWT-2.8.0.tar.gz", hash = "sha256:57e28d156e3d5c10088e0c68abb90bfac3df82b40a71bd0daa20c65ccd5c23de"}, -] - [[package]] name = "pyopengl" version = "3.1.7" @@ -2334,17 +2120,6 @@ files = [ {file = 
"python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, ] -[[package]] -name = "python-multipart" -version = "0.0.9" -requires_python = ">=3.8" -summary = "A streaming multipart parser for Python" -groups = ["default"] -files = [ - {file = "python_multipart-0.0.9-py3-none-any.whl", hash = "sha256:97ca7b8ea7b05f977dc3849c3ba99d51689822fab725c3703af7c866a0c2b215"}, - {file = "python_multipart-0.0.9.tar.gz", hash = "sha256:03f54688c663f1b7977105f021043b0793151e4cb1c1a9d4a11fc13d622c4026"}, -] - [[package]] name = "pytinyrenderer" version = "0.0.14" @@ -2357,6 +2132,28 @@ files = [ {file = "pytinyrenderer-0.0.14.tar.gz", hash = "sha256:5fedb4798509cb911a03a3bc9e8de8d4d5aa36b1de52eb878efef104b95a3d15"}, ] +[[package]] +name = "pytorch-lightning" +version = "2.3.0" +requires_python = ">=3.8" +summary = "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate." +groups = ["default"] +dependencies = [ + "PyYAML>=5.4", + "fsspec[http]>=2022.5.0", + "lightning-utilities>=0.8.0", + "numpy>=1.17.2", + "packaging>=20.0", + "torch>=2.0.0", + "torchmetrics>=0.7.0", + "tqdm>=4.57.0", + "typing-extensions>=4.4.0", +] +files = [ + {file = "pytorch-lightning-2.3.0.tar.gz", hash = "sha256:89caf90e3543b314508493f26e0eca8d5e10e43e3d9e6c143acd8ddceb584ce2"}, + {file = "pytorch_lightning-2.3.0-py3-none-any.whl", hash = "sha256:b8eec361f4342ca628d0d8e6985511c9515435e4db62c5e982bb1c53a5a5140a"}, +] + [[package]] name = "pytorch2jax" version = "0.1.0" @@ -2400,17 +2197,6 @@ files = [ {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, ] -[[package]] -name = "readchar" -version = "4.1.0" -requires_python = ">=3.8" -summary = "Library to easily read single chars and key strokes" -groups = ["default"] -files = [ - {file = "readchar-4.1.0-py3-none-any.whl", hash = "sha256:d163680656b34f263fb5074023db44b999c68ff31ab394445ebfd1a2a41fe9a2"}, - {file = "readchar-4.1.0.tar.gz", hash = "sha256:6f44d1b5f0fd93bd93236eac7da39609f15df647ab9cea39f5bc7478b3344b99"}, -] - [[package]] name = "requests" version = "2.32.3" @@ -2485,34 +2271,6 @@ files = [ {file = "ruff-0.4.9.tar.gz", hash = "sha256:f1cb0828ac9533ba0135d148d214e284711ede33640465e706772645483427e3"}, ] -[[package]] -name = "runs" -version = "1.2.2" -requires_python = ">=3.8" -summary = "🏃 Run a block of text as a subprocess 🏃" -groups = ["default"] -dependencies = [ - "xmod", -] -files = [ - {file = "runs-1.2.2-py3-none-any.whl", hash = "sha256:0980dcbc25aba1505f307ac4f0e9e92cbd0be2a15a1e983ee86c24c87b839dfd"}, - {file = "runs-1.2.2.tar.gz", hash = "sha256:9dc1815e2895cfb3a48317b173b9f1eac9ba5549b36a847b5cc60c3bf82ecef1"}, -] - -[[package]] -name = "s3transfer" -version = "0.10.1" -requires_python = ">= 3.8" -summary = "An Amazon S3 Transfer Manager" -groups = ["default"] -dependencies = [ - "botocore<2.0a.0,>=1.33.2", -] -files = [ - {file = "s3transfer-0.10.1-py3-none-any.whl", hash = "sha256:ceb252b11bcf87080fb7850a224fb6e05c8a776bab8f2b64b7f25b969464839d"}, - {file = "s3transfer-0.10.1.tar.gz", hash = "sha256:5683916b4c724f799e600f41dd9e10a9ff19871bf87623cc8f491cb4f5fa0a19"}, -] - [[package]] name = "scipy" version = "1.13.1" @@ -2550,7 +2308,7 @@ files = [ [[package]] name = "sentry-sdk" -version = "2.3.1" +version = "2.6.0" requires_python = ">=3.6" summary = "Python client for Sentry (https://sentry.io)" groups = ["default"] @@ -2559,8 +2317,8 @@ dependencies = [ 
"urllib3>=1.26.11", ] files = [ - {file = "sentry_sdk-2.3.1-py2.py3-none-any.whl", hash = "sha256:c5aeb095ba226391d337dd42a6f9470d86c9fc236ecc71cfc7cd1942b45010c6"}, - {file = "sentry_sdk-2.3.1.tar.gz", hash = "sha256:139a71a19f5e9eb5d3623942491ce03cf8ebc14ea2e39ba3e6fe79560d8a5b1f"}, + {file = "sentry_sdk-2.6.0-py2.py3-none-any.whl", hash = "sha256:422b91cb49378b97e7e8d0e8d5a1069df23689d45262b86f54988a7db264e874"}, + {file = "sentry_sdk-2.6.0.tar.gz", hash = "sha256:65cc07e9c6995c5e316109f138570b32da3bd7ff8d0d0ee4aaf2628c3dd8127d"}, ] [[package]] @@ -2603,13 +2361,28 @@ files = [ [[package]] name = "setuptools" -version = "70.0.0" +version = "70.1.0" requires_python = ">=3.8" summary = "Easily download, build, install, upgrade, and uninstall Python packages" groups = ["default"] files = [ - {file = "setuptools-70.0.0-py3-none-any.whl", hash = "sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4"}, - {file = "setuptools-70.0.0.tar.gz", hash = "sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0"}, + {file = "setuptools-70.1.0-py3-none-any.whl", hash = "sha256:d9b8b771455a97c8a9f3ab3448ebe0b29b5e105f1228bba41028be116985a267"}, + {file = "setuptools-70.1.0.tar.gz", hash = "sha256:01a1e793faa5bd89abc851fa15d0a0db26f160890c7102cd8dce643e886b47f5"}, +] + +[[package]] +name = "simple-parsing" +version = "0.1.5" +requires_python = ">=3.7" +summary = "A small utility for simplifying and cleaning up argument parsing scripts." +groups = ["default"] +dependencies = [ + "docstring-parser~=0.15", + "typing-extensions>=4.5.0", +] +files = [ + {file = "simple_parsing-0.1.5-py3-none-any.whl", hash = "sha256:46f35ed7002f9bb25dca3a49eac491cc78d2140e4adcbe156225ae643c2874ea"}, + {file = "simple_parsing-0.1.5.tar.gz", hash = "sha256:d26ac15be5173cf28174e171a68153c11e462ad2cb3c23d3ad8634b00719d1fc"}, ] [[package]] @@ -2634,17 +2407,6 @@ files = [ {file = "smmap-5.0.1.tar.gz", hash = "sha256:dceeb6c0028fdb6734471eb07c0cd2aae706ccaecab45965ee83f11c8d3b1f62"}, ] -[[package]] -name = "sniffio" -version = "1.3.1" -requires_python = ">=3.7" -summary = "Sniff out which async library your code is running under" -groups = ["default"] -files = [ - {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, - {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, -] - [[package]] name = "soupsieve" version = "2.5" @@ -2656,35 +2418,6 @@ files = [ {file = "soupsieve-2.5.tar.gz", hash = "sha256:5663d5a7b3bfaeee0bc4372e7fc48f9cff4940b3eec54a6451cc5299f1097690"}, ] -[[package]] -name = "starlette" -version = "0.22.0" -requires_python = ">=3.7" -summary = "The little ASGI library that shines." -groups = ["default"] -dependencies = [ - "anyio<5,>=3.4.0", -] -files = [ - {file = "starlette-0.22.0-py3-none-any.whl", hash = "sha256:b5eda991ad5f0ee5d8ce4c4540202a573bb6691ecd0c712262d0bc85cf8f2c50"}, - {file = "starlette-0.22.0.tar.gz", hash = "sha256:b092cbc365bea34dd6840b42861bdabb2f507f8671e642e8272d2442e08ea4ff"}, -] - -[[package]] -name = "starsessions" -version = "1.3.0" -requires_python = ">=3.6.2,<4.0.0" -summary = "Pluggable session support for Starlette." 
-groups = ["default"] -dependencies = [ - "itsdangerous<3.0.0,>=2.0.1", - "starlette<1,>=0", -] -files = [ - {file = "starsessions-1.3.0-py3-none-any.whl", hash = "sha256:c0758f2a1a2438ec7ba88b232e82008f2261a75584f01179c787b3636fae6040"}, - {file = "starsessions-1.3.0.tar.gz", hash = "sha256:8d3b509d4e6d235655f7dd495fcf0afc1bd86da84de3a8d434e6f82137ebcde8"}, -] - [[package]] name = "submitit" version = "1.5.1" @@ -2791,20 +2524,20 @@ files = [ [[package]] name = "tensorstore" -version = "0.1.60" +version = "0.1.62" requires_python = ">=3.9" summary = "Read and write large, multi-dimensional arrays" groups = ["default"] dependencies = [ "ml-dtypes>=0.3.1", - "numpy>=1.16.0", + "numpy>=1.22.0", ] files = [ - {file = "tensorstore-0.1.60-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:65677e21304fcf272557f195c597704f4ccf55b75314e68ece17bb1784cb59f7"}, - {file = "tensorstore-0.1.60-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:725d1f70c17838815704805d2853c636bb2d680424e81f91677a7defea68373b"}, - {file = "tensorstore-0.1.60-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c477a0e6948326c414ed1bcdab2949e975f0b4e7e449cce39e0fec14b273e1b2"}, - {file = "tensorstore-0.1.60-cp312-cp312-win_amd64.whl", hash = "sha256:32cba3cf0ae6dd03d504162b8ea387f140050e279cf23e7eced68d3c845693da"}, - {file = "tensorstore-0.1.60.tar.gz", hash = "sha256:88da8f1978982101b8dbb144fd29ee362e4e8c97fc595c4992d555f80ce62a79"}, + {file = "tensorstore-0.1.62-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:616cd5d55ff6e2979d6f4578ad76c1d12dfdb361d43edfd90728b558857f33b9"}, + {file = "tensorstore-0.1.62-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6824d3f49fc2c75c7a5da1b77e840014565660852bff2544c38ccafbe63ed5a7"}, + {file = "tensorstore-0.1.62-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78e081786b293bf3a4acf2ae54d62d25c82a21ad9503c0986ba6fcf03c6c9fbb"}, + {file = "tensorstore-0.1.62-cp312-cp312-win_amd64.whl", hash = "sha256:446e46dfd149ab516fdf47598684fc472b206afcd5e365e0e3e55c7f280cc288"}, + {file = "tensorstore-0.1.62.tar.gz", hash = "sha256:d0e88dae5d983e500700f9f1636eaa742f9e673b4a230d7126f1380e021f373f"}, ] [[package]] @@ -2854,15 +2587,15 @@ files = [ [[package]] name = "torch-jax-interop" version = "0.0.4.post7.dev0" -requires_python = "<4.0,>=3.11" +requires_python = ">=3.11,<4.0" git = "https://www.github.com/lebrice/torch_jax_interop" revision = "3a4261f949d739cfe684280203137114b169e70e" summary = "Utility to convert Tensors from Jax to Torch and vice-versa" groups = ["default"] dependencies = [ - "flax<1.0.0,>=0.8.4", - "jax[cuda12]<1.0.0,>=0.4.28", - "pytorch2jax<1.0.0,>=0.1.0", + "flax<0.9.0,>=0.8.4", + "jax[cuda12]<0.5.0,>=0.4.28", + "pytorch2jax<0.2.0,>=0.1.0", "torch<3.0.0,>=2.0.0", ] @@ -2915,20 +2648,9 @@ files = [ {file = "tqdm-4.66.4.tar.gz", hash = "sha256:e4d936c9de8727928f3be6079590e97d9abfe8d39a590be678eb5919ffc186bb"}, ] -[[package]] -name = "traitlets" -version = "5.14.3" -requires_python = ">=3.8" -summary = "Traitlets Python configuration system" -groups = ["default"] -files = [ - {file = "traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f"}, - {file = "traitlets-5.14.3.tar.gz", hash = "sha256:9ed0579d3502c94b4b3732ac120375cda96f923114522847de4b3bb98b96b6b7"}, -] - [[package]] name = "trimesh" -version = "4.4.0" +version = "4.4.1" requires_python = ">=3.7" summary = "Import, export, process, analyze and view triangular meshes." 
groups = ["default"] @@ -2936,30 +2658,19 @@ dependencies = [ "numpy>=1.20", ] files = [ - {file = "trimesh-4.4.0-py3-none-any.whl", hash = "sha256:e192458da391c1b0a850df0b713c59234a6582e641569b004b588ada337b05c0"}, - {file = "trimesh-4.4.0.tar.gz", hash = "sha256:daf6e56715de2e93dd905e926f9bb10d23dc4157f9724aa7caab5d0e28963e56"}, -] - -[[package]] -name = "types-python-dateutil" -version = "2.9.0.20240316" -requires_python = ">=3.8" -summary = "Typing stubs for python-dateutil" -groups = ["default"] -files = [ - {file = "types-python-dateutil-2.9.0.20240316.tar.gz", hash = "sha256:5d2f2e240b86905e40944dd787db6da9263f0deabef1076ddaed797351ec0202"}, - {file = "types_python_dateutil-2.9.0.20240316-py3-none-any.whl", hash = "sha256:6b8cb66d960771ce5ff974e9dd45e38facb81718cc1e208b10b1baccbfdbee3b"}, + {file = "trimesh-4.4.1-py3-none-any.whl", hash = "sha256:dc00e293f4efed692b57e95ff9dafd5b62f2126439fb377d2a6b048d7d086933"}, + {file = "trimesh-4.4.1.tar.gz", hash = "sha256:767fe3c866ba74e6d9a9d216c34ecc1cfe2fbf3f129a6c11d59871705a591aba"}, ] [[package]] name = "typing-extensions" -version = "4.12.0" +version = "4.12.2" requires_python = ">=3.8" summary = "Backported and Experimental Type Hints for Python 3.8+" groups = ["default", "dev"] files = [ - {file = "typing_extensions-4.12.0-py3-none-any.whl", hash = "sha256:b349c66bea9016ac22978d800cfff206d5f9816951f12a7d0ec5578b0a819594"}, - {file = "typing_extensions-4.12.0.tar.gz", hash = "sha256:8cbcdc8606ebcb0d95453ad7dc5065e6237b6aa230a31e81d0f440c30fed5fd8"}, + {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, + {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] [[package]] @@ -2975,28 +2686,13 @@ files = [ [[package]] name = "urllib3" -version = "2.2.1" +version = "2.2.2" requires_python = ">=3.8" summary = "HTTP library with thread-safe connection pooling, file post, and more." groups = ["default"] files = [ - {file = "urllib3-2.2.1-py3-none-any.whl", hash = "sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d"}, - {file = "urllib3-2.2.1.tar.gz", hash = "sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19"}, -] - -[[package]] -name = "uvicorn" -version = "0.30.0" -requires_python = ">=3.8" -summary = "The lightning-fast ASGI server." 
-groups = ["default"] -dependencies = [ - "click>=7.0", - "h11>=0.8", -] -files = [ - {file = "uvicorn-0.30.0-py3-none-any.whl", hash = "sha256:78fa0b5f56abb8562024a59041caeb555c86e48d0efdd23c3fe7de7a4075bdab"}, - {file = "uvicorn-0.30.0.tar.gz", hash = "sha256:f678dec4fa3a39706bbf49b9ec5fc40049d42418716cea52b53f07828a60aa37"}, + {file = "urllib3-2.2.2-py3-none-any.whl", hash = "sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472"}, + {file = "urllib3-2.2.2.tar.gz", hash = "sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168"}, ] [[package]] @@ -3029,53 +2725,6 @@ files = [ {file = "wandb-0.17.2-py3-none-win_amd64.whl", hash = "sha256:62cd707f38b5711971729dae80343b8c35f6003901e690166cc6d526187a9785"}, ] -[[package]] -name = "wcwidth" -version = "0.2.13" -summary = "Measures the displayed width of unicode strings in a terminal" -groups = ["default"] -files = [ - {file = "wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859"}, - {file = "wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5"}, -] - -[[package]] -name = "websocket-client" -version = "1.8.0" -requires_python = ">=3.8" -summary = "WebSocket client for Python with low level API options" -groups = ["default"] -files = [ - {file = "websocket_client-1.8.0-py3-none-any.whl", hash = "sha256:17b44cc997f5c498e809b22cdf2d9c7a9e71c02c8cc2b6c56e7c2d1239bfa526"}, - {file = "websocket_client-1.8.0.tar.gz", hash = "sha256:3239df9f44da632f96012472805d40a23281a991027ce11d2f45a6f24ac4c3da"}, -] - -[[package]] -name = "websockets" -version = "11.0.3" -requires_python = ">=3.7" -summary = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" -groups = ["default"] -files = [ - {file = "websockets-11.0.3-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:f2e58f2c36cc52d41f2659e4c0cbf7353e28c8c9e63e30d8c6d3494dc9fdedcf"}, - {file = "websockets-11.0.3-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de36fe9c02995c7e6ae6efe2e205816f5f00c22fd1fbf343d4d18c3d5ceac2f5"}, - {file = "websockets-11.0.3-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0ac56b661e60edd453585f4bd68eb6a29ae25b5184fd5ba51e97652580458998"}, - {file = "websockets-11.0.3-pp37-pypy37_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e052b8467dd07d4943936009f46ae5ce7b908ddcac3fda581656b1b19c083d9b"}, - {file = "websockets-11.0.3-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:42cc5452a54a8e46a032521d7365da775823e21bfba2895fb7b77633cce031bb"}, - {file = "websockets-11.0.3-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:e6316827e3e79b7b8e7d8e3b08f4e331af91a48e794d5d8b099928b6f0b85f20"}, - {file = "websockets-11.0.3-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8531fdcad636d82c517b26a448dcfe62f720e1922b33c81ce695d0edb91eb931"}, - {file = "websockets-11.0.3-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c114e8da9b475739dde229fd3bc6b05a6537a88a578358bc8eb29b4030fac9c9"}, - {file = "websockets-11.0.3-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e063b1865974611313a3849d43f2c3f5368093691349cf3c7c8f8f75ad7cb280"}, - {file = "websockets-11.0.3-pp38-pypy38_pp73-win_amd64.whl", hash = 
"sha256:92b2065d642bf8c0a82d59e59053dd2fdde64d4ed44efe4870fa816c1232647b"}, - {file = "websockets-11.0.3-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:0ee68fe502f9031f19d495dae2c268830df2760c0524cbac5d759921ba8c8e82"}, - {file = "websockets-11.0.3-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dcacf2c7a6c3a84e720d1bb2b543c675bf6c40e460300b628bab1b1efc7c034c"}, - {file = "websockets-11.0.3-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b67c6f5e5a401fc56394f191f00f9b3811fe843ee93f4a70df3c389d1adf857d"}, - {file = "websockets-11.0.3-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d5023a4b6a5b183dc838808087033ec5df77580485fc533e7dab2567851b0a4"}, - {file = "websockets-11.0.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:ed058398f55163a79bb9f06a90ef9ccc063b204bb346c4de78efc5d15abfe602"}, - {file = "websockets-11.0.3-py3-none-any.whl", hash = "sha256:6681ba9e7f8f3b19440921e99efbb40fc89f26cd71bf539e45d8c8a25c976dc6"}, - {file = "websockets-11.0.3.tar.gz", hash = "sha256:88fc51d9a26b10fc331be344f1781224a375b78488fc343620184e95a4b27016"}, -] - [[package]] name = "werkzeug" version = "3.0.3" @@ -3090,17 +2739,6 @@ files = [ {file = "werkzeug-3.0.3.tar.gz", hash = "sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18"}, ] -[[package]] -name = "xmod" -version = "1.8.1" -requires_python = ">=3.8" -summary = "🌱 Turn any object into a module 🌱" -groups = ["default"] -files = [ - {file = "xmod-1.8.1-py3-none-any.whl", hash = "sha256:a24e9458a4853489042522bdca9e50ee2eac5ab75c809a91150a8a7f40670d48"}, - {file = "xmod-1.8.1.tar.gz", hash = "sha256:38c76486b9d672c546d57d8035df0beb7f4a9b088bc3fb2de5431ae821444377"}, -] - [[package]] name = "yarl" version = "1.9.4" @@ -3133,11 +2771,11 @@ files = [ [[package]] name = "zipp" -version = "3.19.1" +version = "3.19.2" requires_python = ">=3.8" summary = "Backport of pathlib-compatible object wrapper for zip files" groups = ["default"] files = [ - {file = "zipp-3.19.1-py3-none-any.whl", hash = "sha256:2828e64edb5386ea6a52e7ba7cdb17bb30a73a858f5eb6eb93d8d36f5ea26091"}, - {file = "zipp-3.19.1.tar.gz", hash = "sha256:35427f6d5594f4acf82d25541438348c26736fa9b3afa2754bcd63cdb99d8e8f"}, + {file = "zipp-3.19.2-py3-none-any.whl", hash = "sha256:f091755f667055f2d02b32c53771a7a6c8b47e1fdbc4b72a8b9072b3eef8015c"}, + {file = "zipp-3.19.2.tar.gz", hash = "sha256:bf1dcf6450f873a13e952a29504887c89e6de7506209e5b1bcc3460135d4de19"}, ] diff --git a/pyproject.toml b/pyproject.toml index 13b0a7d7..6f7025a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ dependencies = [ "tqdm>=4.66.2", "hydra-zen>=0.12.1", "gym==0.26.2", - "lightning==1.9.0", + "lightning>=2.3.0", "gdown>=5.1.0", "hydra-submitit-launcher>=1.2.0", "wandb>=0.16.4", @@ -27,6 +27,8 @@ dependencies = [ "gymnax>=0.0.8", "torch-jax-interop @ git+https://www.github.com/lebrice/torch_jax_interop", "tensor-regression @ git+https://www.github.com/lebrice/tensor_regression", + "simple-parsing>=0.1.5", + "pydantic==2.7.4", ] requires-python = ">=3.12" readme = "README.md" From 4612eab4ad260bb580864c5a273acf617eb1695d Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Thu, 20 Jun 2024 20:54:29 +0000 Subject: [PATCH 25/27] Fix integration tests that have to do with dataset Signed-off-by: Fabrice Normandin --- .../image_classification/imagenet32_test.py | 2 ++ project/utils/testutils.py | 11 +++++------ 2 files 
changed, 7 insertions(+), 6 deletions(-) diff --git a/project/datamodules/image_classification/imagenet32_test.py b/project/datamodules/image_classification/imagenet32_test.py index 96696049..18549607 100644 --- a/project/datamodules/image_classification/imagenet32_test.py +++ b/project/datamodules/image_classification/imagenet32_test.py @@ -3,10 +3,12 @@ import pytest from project.utils.env_vars import DATA_DIR, SCRATCH +from project.utils.testutils import IN_GITHUB_CI from .imagenet32 import ImageNet32DataModule +@pytest.mark.skipif(IN_GITHUB_CI, reason="Can't run ") @pytest.mark.slow def test_dataset_download_works(): batch_size = 16 diff --git a/project/utils/testutils.py b/project/utils/testutils.py index 617c999e..184a0605 100644 --- a/project/utils/testutils.py +++ b/project/utils/testutils.py @@ -15,7 +15,6 @@ from pathlib import Path from typing import Any, TypeVar -import hydra.errors import hydra_zen import numpy as np import pytest @@ -41,8 +40,8 @@ logger = get_logger(__name__) -on_github_ci = "GITHUB_ACTIONS" in os.environ -on_self_hosted_github_ci = on_github_ci and "self-hosted" in os.environ.get("RUNNER_LABELS", "") +IN_GITHUB_CI = "GITHUB_ACTIONS" in os.environ +IN_SELF_HOSTED_GITHUB_CI = IN_GITHUB_CI and "self-hosted" in os.environ.get("RUNNER_LABELS", "") default_marks_for_config_name: dict[str, list[pytest.MarkDecorator]] = { @@ -58,10 +57,10 @@ ], "imagenet": [ pytest.mark.slow, - pytest.mark.xfail( + pytest.mark.skipif( not (NETWORK_DIR and (NETWORK_DIR / "datasets/imagenet").exists()), - strict=True, - raises=hydra.errors.InstantiationException, + # strict=True, + # raises=hydra.errors.InstantiationException, reason="Expects to be run on a cluster with the ImageNet dataset.", ), ], From 5020921bc401be98f0f5aa08da8221ae846eedda Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Thu, 20 Jun 2024 20:55:56 +0000 Subject: [PATCH 26/27] Fix typo Signed-off-by: Fabrice Normandin --- project/datamodules/image_classification/imagenet32_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/project/datamodules/image_classification/imagenet32_test.py b/project/datamodules/image_classification/imagenet32_test.py index 18549607..537c91ce 100644 --- a/project/datamodules/image_classification/imagenet32_test.py +++ b/project/datamodules/image_classification/imagenet32_test.py @@ -8,7 +8,7 @@ from .imagenet32 import ImageNet32DataModule -@pytest.mark.skipif(IN_GITHUB_CI, reason="Can't run ") +@pytest.mark.skipif(IN_GITHUB_CI, reason="Can't run on the GitHub CI.") @pytest.mark.slow def test_dataset_download_works(): batch_size = 16 From 681cee0ba88fa0f9841135c2421603424ca0e8b9 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Fri, 21 Jun 2024 22:01:52 +0000 Subject: [PATCH 27/27] Update devcontainer file Signed-off-by: Fabrice Normandin --- .devcontainer/devcontainer.json | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 25923af4..9fc952f1 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -66,11 +66,9 @@ // https://code.visualstudio.com/remote/advancedcontainers/add-local-file-mount // Mount a directory which will contain the pdm installation cache (shared with the host machine). // This will use $SCRATCH/.cache/pdm, otherwise - "source=${localEnv:SCRATCH:~}/.cache/pdm,target=/home/vscode/.pdm_install_cache,type=bind,consistency=cached", // Mount a "$SCRATCH" directory in the host to ~/scratch in the container. 
-    // FIXME: This assumes that either the SCRATCH environment variable is set on the host, or
-    // that the $HOME/scratch directory exists.
-    "source=${localEnv:SCRATCH:~/scratch},target=/home/vscode/scratch,type=bind,consistency=cached",
+    "source=${localEnv:SCRATCH},target=/home/vscode/scratch,type=bind,consistency=cached",
+    "source=${localEnv:SCRATCH}/.cache/pdm,target=/home/vscode/.pdm_install_cache,type=bind,consistency=cached",
     // Mount a /network to match the /network directory on the host.
     // FIXME: This assumes that either the NETWORK_DIR environment variable is set on the host, or
     // that the /network directory exists.
@@ -89,8 +87,8 @@
     // create the pdm cache dir on the host machine if it doesn't exist yet so the mount above
     // doesn't fail.
     "initializeCommand": {
-        "create pdm install cache": "mkdir -p ~/.cache/pdm", // todo: put this on $SCRATCH on the host (e.g. compute node)
-        "create fake SLURM_TMPDIR": "mkdir -p /tmp/slurm_tmpdir" // this is fine on compute nodes
+        "create pdm install cache": "mkdir -p ${SCRATCH?need the SCRATCH environment variable to be set.}/.cache/pdm", // todo: put this on $SCRATCH on the host (e.g. compute node)
+        "create fake SLURM_TMPDIR": "mkdir -p ${SLURM_TMPDIR?need the SLURM_TMPDIR environment variable to be set.}" // this is fine on compute nodes
     },
     // NOTE: Getting some permission issues with the .cache dir if mounting .cache/pdm to
     // .cache/pdm in the container. Therefore, here I'm making a symlink from ~/.cache/pdm to