Fix immediate StopIteration in DataLoaderShard #3368

Open
wants to merge 7 commits into
base: main

Aleko2286

Fixes DataLoaderShard so that it also stops iteration immediately (raising StopIteration) when its underlying dataloader raises StopIteration immediately, i.e. when the dataloader is empty.

Fixes #3367
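The intended behavior can be sketched in plain Python. This is a simplified, hypothetical stand-in for the one-batch lookahead in DataLoaderShard.__iter__ (no device transfer, RNG synchronization, skip_batches or state-dict bookkeeping), and `iterate_with_lookahead` is an illustrative name, not an Accelerate API. When the very first `next()` fails, the generator simply returns, so the caller sees `StopIteration` right away instead of an extra item or an error.

```python
from typing import Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")


def iterate_with_lookahead(iterable: Iterable[T]) -> Iterator[Tuple[T, bool]]:
    """Yield (batch, is_last) pairs, reading one batch ahead.

    Hypothetical simplification of a lookahead loop; not Accelerate's code.
    """
    it = iter(iterable)
    try:
        current = next(it)
    except StopIteration:
        # Empty input: returning ends the generator, so the caller's
        # next() raises StopIteration immediately.
        return
    while True:
        try:
            upcoming = next(it)
        except StopIteration:
            # Nothing left to read ahead: `current` is the final batch.
            yield current, True
            return
        yield current, False
        current = upcoming


print(list(iterate_with_lookahead([])))         # []
print(list(iterate_with_lookahead([1, 2, 3])))  # [(1, False), (2, False), (3, True)]
```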

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@muellerzr

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Aleko2286
Author

This should be functionally identical to main now, except for two things:

  1. It doesn't error on empty dataloaders
  2. On main, self._update_state_dict() is called twice before the final element is yielded, while here it is called only once:
def __iter__(self):
    if self.rng_types is not None:
        synchronize_rng_states(self.rng_types, self.synchronized_generator)
    self.begin()

    self.set_epoch(self.iteration)
    dataloader_iter = self.base_dataloader.__iter__()
    # We iterate one batch ahead to check when we are at the end
    try:
        current_batch = next(dataloader_iter)
    except StopIteration:
        yield

    batch_index = 0
    while True:
        try:
            # But we still move it to the device so it is done before `StopIteration` is reached
            if self.device is not None:
                current_batch = send_to_device(current_batch, self.device, non_blocking=self._non_blocking)
            # First call for final element
            self._update_state_dict()
            # This gets a StopIteration for the final element
            next_batch = next(dataloader_iter)
            if batch_index >= self.skip_batches:
                yield current_batch
            batch_index += 1
            current_batch = next_batch
        except StopIteration:
            self.end_of_dataloader = True
            # Second call
            self._update_state_dict()
            if batch_index >= self.skip_batches:
                yield current_batch
            break

    self.iteration += 1
    self.end()

Judging from the description of what self._update_state_dict() does, this doesn't look intentional, since the lookahead would be accounted for twice. If it is intentional, however, the PR needs to be updated again.
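To make the double call concrete, here is a toy model that replays the control flow of the quoted __iter__ on main (device transfer, RNG synchronization and skip_batches omitted) and counts how many calls to a stand-in for `_update_state_dict()` happen before each yield. `ToyShard`, `calls_per_batch` and `_emit` are hypothetical names, and the stand-in only counts calls; it does not snapshot any real state. For a three-batch loader every batch is preceded by one call except the last, which is preceded by two.

```python
from typing import Any, Iterable, Iterator, List


class ToyShard:
    """Hypothetical model of the quoted loop on main; not Accelerate's class."""

    def __init__(self, data: Iterable[Any]) -> None:
        self.data = data
        self.end_of_dataloader = False
        self.calls_per_batch: List[int] = []  # update calls made before each yield
        self._pending_calls = 0

    def _update_state_dict(self) -> None:
        # Stand-in: only counts calls instead of snapshotting any state.
        self._pending_calls += 1

    def _emit(self, batch: Any) -> Any:
        # Record how many update calls preceded this batch, then reset.
        self.calls_per_batch.append(self._pending_calls)
        self._pending_calls = 0
        return batch

    def __iter__(self) -> Iterator[Any]:
        it = iter(self.data)
        try:
            current = next(it)
        except StopIteration:
            return  # empty loader: stop at once (main has a bare `yield` here)
        while True:
            try:
                self._update_state_dict()  # first call, made for every batch
                upcoming = next(it)  # raises StopIteration at the final batch
                yield self._emit(current)
                current = upcoming
            except StopIteration:
                self.end_of_dataloader = True
                self._update_state_dict()  # second call, final batch only
                yield self._emit(current)
                break


shard = ToyShard([10, 20, 30])
print(list(shard))            # [10, 20, 30]
print(shard.calls_per_batch)  # [1, 1, 2]
```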

@Aleko2286
Author

Now it should be completely identical to main, except that it no longer crashes on empty iterators. I am still not sure whether calling self._update_state_dict() a second time is intentional, but that second call is what sets "_iterator_finished", which would otherwise never be set given where self._update_state_dict() currently sits in the loop.
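The "_iterator_finished" point can be illustrated with a similar toy. Assuming, for illustration only, that each state update records the current end_of_dataloader value under "_iterator_finished" (the real `_update_state_dict()` may do more than this), a single call placed before the lookahead `next()` never observes the end of the loader, while the extra call in the StopIteration branch does. `snapshots_for` and its `second_call` parameter are hypothetical.

```python
from typing import Any, Dict, Iterable, List

_END = object()  # sentinel marking an exhausted iterator


def snapshots_for(data: Iterable[Any], second_call: bool) -> List[Dict[str, bool]]:
    """Hypothetical model: collect every state snapshot taken during one pass."""
    snapshots: List[Dict[str, bool]] = []
    end_of_dataloader = False

    def update_state_dict() -> None:
        # Assumption for illustration: the snapshot records the current flag.
        snapshots.append({"_iterator_finished": end_of_dataloader})

    it = iter(data)
    if next(it, _END) is _END:
        return snapshots  # empty loader: nothing yielded, nothing recorded
    while True:
        update_state_dict()  # placed before the lookahead, as in the loop
        if next(it, _END) is _END:
            end_of_dataloader = True
            if second_call:
                update_state_dict()  # the extra call in the StopIteration branch
            return snapshots
        # (the current batch would be yielded here)


print(snapshots_for([1, 2, 3], second_call=False))
# [{'_iterator_finished': False}, {'_iterator_finished': False}, {'_iterator_finished': False}]
print(snapshots_for([1, 2, 3], second_call=True)[-1])
# {'_iterator_finished': True}
```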

Member

@BenjaminBossan left a comment

Thanks for the PR and for ensuring that the tests pass. Would you please also add your initial snippet as a unit test to ensure that it works and keeps working?

I'm not super knowledgeable about why these specific steps are needed for iteration, so these are just my thoughts, but @muellerzr should have the final say.

First, I can see that your PR tries to keep the code inside the try ... except blocks to a minimum, which is a good thing. It also removes the empty yield, which looks like an improvement to me, but maybe I'm missing something that requires it.
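For context on the removed empty yield: in a generator, a bare `yield` hands the caller `None` and then execution continues with the next statement, so an `except StopIteration: yield` branch does not end the iteration. The two toy functions below are hypothetical (not Accelerate code); the first has the same shape as the quoted __iter__ on main, where the code after the `except` goes on to use a variable that was never bound, and the second uses `return` the way the PR does.

```python
from typing import Any, Iterable, Iterator


def bare_yield_on_empty(iterable: Iterable[Any]) -> Iterator[Any]:
    """Toy with the same shape as the quoted __iter__ on main."""
    it = iter(iterable)
    try:
        current = next(it)
    except StopIteration:
        yield  # hands the caller a None; execution then continues below
    yield current  # `current` was never bound when the input is empty
    yield from it


def return_on_empty(iterable: Iterable[Any]) -> Iterator[Any]:
    """Toy with the PR's approach: stop as soon as the input is empty."""
    it = iter(iterable)
    try:
        current = next(it)
    except StopIteration:
        return  # ends the generator; the caller's next() raises StopIteration
    yield current
    yield from it


gen = bare_yield_on_empty([])
print(next(gen))  # None
try:
    next(gen)
except UnboundLocalError as exc:
    print("UnboundLocalError:", exc)

print(list(return_on_empty([])))  # []
```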

Overall, the handling still looks quite complicated to me, with the explicit next() calls and the catching of StopIteration. I wonder if there is a way to iterate over the data loader directly, along the lines of:

for batch_index, batch in enumerate(self.base_dataloader):
    ...  # do the book keeping and device handling
    if batch_index >= self.skip_batches:
        yield batch

self.iteration += 1
self.end()
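A runnable toy version of the suggested loop, using a hypothetical `ToyLoader` in place of DataLoaderShard (device handling, state bookkeeping and the begin()/end() hooks are omitted). Iterating the underlying loader directly handles an empty loader for free, since the `for` loop simply runs zero times. The trade-off is that the loop only learns it has reached the end after the last batch has already been yielded, which is what the one-batch lookahead on main is there to avoid; the toy records the end_of_dataloader value seen with each batch to show this.

```python
from typing import Any, Iterator, List


class ToyLoader:
    """Hypothetical stand-in for DataLoaderShard; not Accelerate's class."""

    def __init__(self, batches: List[Any], skip_batches: int = 0) -> None:
        self.base_dataloader = batches
        self.skip_batches = skip_batches
        self.iteration = 0
        self.end_of_dataloader = False
        # Value of end_of_dataloader at the moment each batch is yielded.
        self.flag_seen_with_batch: List[bool] = []

    def __iter__(self) -> Iterator[Any]:
        self.end_of_dataloader = False
        for batch_index, batch in enumerate(self.base_dataloader):
            # Device transfer and state bookkeeping would go here.
            if batch_index >= self.skip_batches:
                self.flag_seen_with_batch.append(self.end_of_dataloader)
                yield batch
        # Without lookahead, the end is only known after the loop finishes,
        # i.e. after the final batch has already been handed to the caller.
        self.end_of_dataloader = True
        self.iteration += 1


loader = ToyLoader([1, 2, 3], skip_batches=1)
print(list(loader))                 # [2, 3]
print(loader.flag_seen_with_batch)  # [False, False]
print(list(ToyLoader([])))          # []
```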

I see the comment "We iterate one batch ahead to check when we are at the end", but tbh I'm not sure what it does. If we really need this, we could use itertools.tee (see the lookahead example in the tee docs).
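If the lookahead is kept, one way to express it with `itertools.tee` is to advance a second copy of the iterator one item ahead of the first; when that copy is exhausted, the current item is the last one. This is a sketch in the spirit of the tee docs rather than their lookahead recipe verbatim, and `iter_with_last_flag` is a hypothetical helper. It yields nothing for an empty input without any special casing, and tee only has to buffer the single item by which the two copies differ.

```python
from itertools import tee
from typing import Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")
_END = object()  # sentinel that cannot collide with a real batch


def iter_with_last_flag(iterable: Iterable[T]) -> Iterator[Tuple[T, bool]]:
    """Yield (item, is_last) pairs using a tee'd copy as a one-item lookahead."""
    current, ahead = tee(iterable)
    next(ahead, _END)  # move the lookahead copy one item ahead
    for item in current:
        # Pull the item after `item` from the lookahead copy; none means last.
        is_last = next(ahead, _END) is _END
        yield item, is_last


print(list(iter_with_last_flag("abc")))  # [('a', False), ('b', False), ('c', True)]
print(list(iter_with_last_flag([])))     # []
```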

3 participants