Commit
fix iterator overflow when gradient accumulation is 1 (#35960)
winglian authored Jan 29, 2025
1 parent 4d3b107 commit 7547f55
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/transformers/trainer.py
@@ -2434,6 +2434,8 @@ def _inner_training_loop(
             remainder = args.gradient_accumulation_steps
             update_step = -1
             total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1
+            if args.gradient_accumulation_steps == 1:
+                total_updates -= 1
             for _ in range(total_updates):
                 update_step += 1
                 num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
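The intent of the patch can be seen by tracing the update count: with `total_updates = steps_in_epoch // gradient_accumulation_steps + 1`, a value of 1 for `gradient_accumulation_steps` yields `steps_in_epoch + 1` loop iterations, so the final iteration tries to draw from an already-exhausted batch iterator. The sketch below is a minimal standalone model of that accounting (the function name and standalone form are illustrative, not the actual Trainer API):

```python
def total_updates(steps_in_epoch: int, gradient_accumulation_steps: int) -> int:
    """Number of optimizer-update iterations per epoch, mirroring the patched formula.

    The trailing `+ 1` reserves one extra iteration for a final partial
    accumulation window; when gradient_accumulation_steps == 1 there is no
    partial window, so the extra iteration is removed (the fix in 7547f55).
    """
    total = steps_in_epoch // gradient_accumulation_steps + 1
    if gradient_accumulation_steps == 1:
        total -= 1  # avoid one iteration past the end of the batch iterator
    return total


# With accumulation disabled (== 1), every batch is its own update:
print(total_updates(10, 1))  # 10 (was 11 before the fix)
# With accumulation of 4 over 10 batches: two full windows plus one remainder window:
print(total_updates(10, 4))  # 3
```

The `gradient_accumulation_steps > 1` case is unchanged: the extra iteration still serves the remainder batches computed just above the loop in the diff.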
