Fix immediate StopIteration in DataLoaderShard #3368
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
This should be functionally identical to main now, except for two things:
    def __iter__(self):
        if self.rng_types is not None:
            synchronize_rng_states(self.rng_types, self.synchronized_generator)
        self.begin()
        self.set_epoch(self.iteration)
        dataloader_iter = self.base_dataloader.__iter__()
        # We iterate one batch ahead to check when we are at the end
        try:
            current_batch = next(dataloader_iter)
        except StopIteration:
            yield
        batch_index = 0
        while True:
            try:
                # But we still move it to the device so it is done before `StopIteration` is reached
                if self.device is not None:
                    current_batch = send_to_device(current_batch, self.device, non_blocking=self._non_blocking)
                # First call for final element
                self._update_state_dict()
                # This gets a StopIteration for the final element
                next_batch = next(dataloader_iter)
                if batch_index >= self.skip_batches:
                    yield current_batch
                batch_index += 1
                current_batch = next_batch
            except StopIteration:
                self.end_of_dataloader = True
                # Second call
                self._update_state_dict()
                if batch_index >= self.skip_batches:
                    yield current_batch
                break
        self.iteration += 1
        self.end()

From the description of what self._update_state_dict() does, this shouldn't be intentional as the lookahead would be accounted for twice. If it is intentional however, the PR needs to be updated again.
Now it should be completely identical except for not crashing on empty iterators. I am still not sure if calling self._update_state_dict() a second time is intentional, but it's setting "_iterator_finished", which would never be set otherwise with the current position of self._update_state_dict() in the loop.
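To show the empty-iterator problem in isolation, here is a minimal sketch that does not depend on accelerate. The names `lookahead_buggy` and `lookahead_fixed` are hypothetical, and the state bookkeeping and device transfer are left out on purpose. It only mirrors the control flow of the one-batch-ahead loop: a bare `yield` in the first `except StopIteration` hands out `None` and then falls through into the loop with `current` never assigned, while a `return` ends the generator cleanly.

```python
def lookahead_buggy(iterable, skip=0):
    """Simplified stand-in for the loop above, keeping the bare `yield`."""
    it = iter(iterable)
    try:
        current = next(it)
    except StopIteration:
        yield  # yields None, then execution continues with `current` unbound
    index = 0
    while True:
        try:
            nxt = next(it)
            if index >= skip:
                yield current
            index += 1
            current = nxt
        except StopIteration:
            if index >= skip:
                yield current  # UnboundLocalError when the input was empty
            break


def lookahead_fixed(iterable, skip=0):
    """Same one-ahead iteration, but an empty input simply ends the generator.

    Yields (item, is_last) pairs so the caller can see the end-of-data flag.
    """
    it = iter(iterable)
    try:
        current = next(it)
    except StopIteration:
        return  # empty input: stop without yielding anything
    index = 0
    while True:
        try:
            nxt = next(it)
        except StopIteration:
            # `current` is the final element
            if index >= skip:
                yield current, True
            return
        if index >= skip:
            yield current, False
        index += 1
        current = nxt
```

Keeping only the `next(it)` call inside the `try` block also avoids accidentally catching a `StopIteration` raised by unrelated code in the loop body.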
Thanks for the PR and for ensuring that the tests pass. Would you please also add your initial snippet as a unit test to ensure that it works and keeps working?
I'm not super knowledgeable about why these specific steps are needed for iteration, so these are just my thoughts but @muellerzr should have the final say.
First, I can see that your PR tries to minimize the code blocks inside the try ... except, which is a good thing. Also, it removes the empty yield, which looks like an improvement to me, but maybe I'm missing something that would require that.
Overall, the handling of this still looks quite complicated to me, with the explicit next() calls and the catching of StopIteration. I wonder if there would not be a way to directly iterate over the data loader, along the lines of:
    for batch_index, batch in enumerate(self.base_dataloader):
        ...  # do the bookkeeping and device handling
        if batch_index >= self.skip_batches:
            yield batch
    self.iteration += 1
    self.end()
I see the comment that "We iterate one batch ahead to check when we are at the end" but I'm not sure what it does tbh. If we really need this, we could use itertools.tee (see the lookahead example of the tee docs).
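As a rough sketch of the itertools.tee idea, the generator below pairs each item with a flag saying whether it is the last one, without explicit next/StopIteration handling in the main loop. The name `with_is_last` is hypothetical, and this is not the lookahead recipe from the tee docs verbatim. It is a variant that splits the input into two tee branches and advances one of them by a single element.

```python
from itertools import tee, zip_longest

_SENTINEL = object()  # marks "no next item" in the lookahead branch


def with_is_last(iterable):
    """Yield (item, is_last) pairs; yields nothing for an empty iterable."""
    current, ahead = tee(iterable)
    next(ahead, None)  # move the lookahead branch one element forward
    for item, upcoming in zip_longest(current, ahead, fillvalue=_SENTINEL):
        # `ahead` runs out exactly one step before `current` does
        yield item, upcoming is _SENTINEL


def iterate_batches(batches, skip_batches=0):
    """Hypothetical stand-in for the data loader loop, bookkeeping omitted."""
    for batch_index, (batch, is_last) in enumerate(with_is_last(batches)):
        if batch_index >= skip_batches:
            yield batch, is_last
```

One trade-off worth noting: tee buffers every element that one branch has consumed and the other has not, so here at most one extra batch is held in memory at a time, the same as with the manual one-ahead loop.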
Fixes DataLoaderShard so that it also raises StopIteration immediately when its underlying dataloader immediately raises StopIteration.
Fixes #3367
Who can review?
@muellerzr
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.