
Add drop-in-replacement for PyTorch DataLoader #303

Merged — 1 commit merged into `main` on Dec 27, 2024
Conversation

@mthrok (Collaborator) commented on Dec 27, 2024:

This commit adds a data loader class that is compatible with PyTorch's DataLoader.

It uses ProcessPoolExecutor, so existing Dataset implementations can be used as-is.

Currently, it only supports map-style datasets.
Attribute-level compatibility with `DataLoader` is not a goal; so far, only the `Iterable` functionality is kept compatible.
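For context, here is a minimal, hypothetical sketch of how a map-style dataset (anything supporting `__getitem__`) can be served through a `ProcessPoolExecutor`. The names below are illustrative assumptions, not the actual API added by this PR:

```python
from concurrent.futures import ProcessPoolExecutor
from functools import partial


def _getitem(dataset, index):
    # Runs in a worker process; the dataset must be picklable.
    return dataset[index]


def iterate_with_pool(dataset, indices, num_workers=2):
    """Yield dataset[i] for each i, fetched in worker processes, in order."""
    with ProcessPoolExecutor(max_workers=num_workers) as pool:
        # Note: partial() ships the dataset with every task, which is
        # wasteful for large datasets; a real implementation would
        # initialize the dataset once per worker.
        yield from pool.map(partial(_getitem, dataset), indices)


if __name__ == "__main__":
    data = [x * x for x in range(8)]  # a plain list is a valid map-style dataset
    print(list(iterate_with_pool(data, range(8))))  # squares, in index order
```

Because the executor only calls `__getitem__`, any existing map-style `Dataset` implementation can be reused without modification, which matches the claim above.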

A quick benchmark (loading 10,000 images from ImageNet) shows that this implementation is much faster than PyTorch's.

![QPS and Time-to-First-batch](https://github.com/user-attachments/assets/6ebff0ae-b523-492b-95af-87973c7b5fca)

**QPS** (images/sec; the higher, the better)

| #workers |      1 |      2 |      4 |       8 |      16 |      32 |
|----------|--------|--------|--------|---------|---------|---------|
| SPDL     | 251.49 | 454.96 |  773.8 | 1146.63 | 1516.97 | 1763.82 |
| PyTorch  | 265.22 |  410.2 | 446.48 |  320.63 |   172.6 |   91.83 |

**TTFB** (seconds; the lower, the better)

| #workers |      1 |      2 |      4 |       8 |      16 |      32 |
|----------|--------|--------|--------|---------|---------|---------|
| SPDL     |   3.86 |   3.57 |   3.55 |    3.95 |    3.72 |    4.07 |
| PyTorch  |   3.49 |   6.79 |  13.32 |    26.6 |   55.32 |  107.43 |
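As a quick sanity check on the QPS numbers, the relative throughput at each worker count can be computed directly (the figures below are copied from the table):

```python
# QPS figures copied from the benchmark tables above.
spdl_qps = {1: 251.49, 2: 454.96, 4: 773.8, 8: 1146.63, 16: 1516.97, 32: 1763.82}
pytorch_qps = {1: 265.22, 2: 410.2, 4: 446.48, 8: 320.63, 16: 172.6, 32: 91.83}

for n in sorted(spdl_qps):
    ratio = spdl_qps[n] / pytorch_qps[n]
    print(f"{n:>2} workers: SPDL/PyTorch QPS ratio = {ratio:.2f}")
# At 1 worker the two are comparable (ratio ~0.95); the gap widens as
# PyTorch's throughput degrades with more workers (~19.2x at 32 workers).
```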
<details><summary>Benchmark code</summary>

```python
import logging
import time

from spdl.dataloader import get_pytorch_dataloader
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageNet
from torchvision.transforms import CenterCrop, Compose, PILToTensor, Resize


def _test(dataset: Dataset, num_workers: int, fn, batch_size: int = 32):
    print(f"{num_workers=}")
    num_items = 0
    t0 = time.monotonic()
    dataloader = fn(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        multiprocessing_context="forkserver",
    )

    for imgs, _ in dataloader:
        if num_items == 0:
            ttfb = time.monotonic() - t0
            print(f"{ttfb=}")
        num_items += len(imgs)
        if num_items > 10000:
            break
    elapsed = time.monotonic() - t0
    qps = num_items / elapsed
    print(f"{num_items=}, {elapsed=:.2f} ({qps=:.2f})")


def _main():
    logging.basicConfig(level=logging.INFO)
    root = "/home/moto/local/imagenet/"

    dataset = ImageNet(
        root=root,
        split="train",
        transform=Compose(
            [
                Resize(256),
                CenterCrop(224),
                PILToTensor(),
            ]
        ),
    )
    for num_workers in (32, 16, 8, 4, 2, 1):
        for fn in (get_pytorch_dataloader, DataLoader):
            _test(dataset, num_workers, fn)


if __name__ == "__main__":
    _main()
```

</details>

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Dec 27, 2024
@facebook-github-bot (Contributor):

@facebook-github-bot has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. (Because this pull request was imported automatically, there will not be any future comments.)

@mthrok mthrok merged commit f9569d6 into main Dec 27, 2024
21 of 31 checks passed
@mthrok mthrok deleted the ptdl branch December 27, 2024 21:48