Fix condition when GA loss bug fix is not performed #35651
base: main
Conversation
@@ -3709,7 +3706,7 @@ def training_step(
                 scaled_loss.backward()
         else:
             # Finally we need to normalize the loss for reporting
-            if num_items_in_batch is None:
+            if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
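The change above narrows the fallback that divides the reported loss by the number of accumulation steps: it now triggers only when the model does not accept loss kwargs and no custom loss function is set. Below is a minimal sketch of that intended behaviour; the helper name and simplified arguments are illustrative, not the Trainer's actual API.

```python
# Minimal sketch (not the actual Trainer source) of the branch changed above.
# The attribute names mirror the diff; everything else is illustrative.
def normalize_loss_for_reporting(loss, model_accepts_loss_kwargs, compute_loss_func, gradient_accumulation_steps):
    # The GA loss fix passes num_items_in_batch into the loss computation, which only
    # works when the model accepts loss kwargs or a custom compute_loss_func is used.
    # Otherwise the fix is not applied, and the mean loss is still divided by the
    # number of accumulation steps so the reported value stays comparable.
    if not model_accepts_loss_kwargs and compute_loss_func is None:
        loss = loss / gradient_accumulation_steps
    return loss


# Example: a model that does not accept loss kwargs, no custom loss function
print(normalize_loss_for_reporting(2.0, False, None, 8))  # 0.25 (fallback normalization)
print(normalize_loss_for_reporting(2.0, True, None, 8))   # 2.0  (GA fix handles it elsewhere)
```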
Looks good to me
@hiyouga let us know if your tests go okay! If they do I'll give this a ✔️
tests/trainer/test_trainer.py (Outdated)
@@ -855,7 +855,7 @@ def tokenize_function(examples):
         self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01")

         # max diff broken should be very off
-        self.assertGreater(max(diff_broken), 3, f"Difference {max(diff_broken)} is not greater than 3")
+        self.assertGreater(max(diff_broken), 2, f"Difference {max(diff_broken)} is not greater than 3")
Suggested change:
-        self.assertGreater(max(diff_broken), 2, f"Difference {max(diff_broken)} is not greater than 3")
+        self.assertGreater(max(diff_broken), 2, f"Difference {max(diff_broken)} is not greater than 2")
Thanks, fixed.
@muellerzr The loss value and gradients can be aligned with the latest main branch after applying this fix. We tested with the Qwen2-VL-7B model and gradient accumulation steps = 8.
I added an extra validation that guarantees the fixed loss and the unfixed loss do not vary too much.
Thanks so much for identifying, checking, and the wonderful collaboration all around 🤗 Arguably one of the hardest tricks we've had to deal with in the Trainer in quite some time, so we really appreciate everyone's effort in making sure we get it fully right for everyone.
@ArthurZucker we're good to go 🚀
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@techkang @hiyouga tentatively something seems off here. One of my two tests fails (the other succeeds). My modified version of the test I'm working on:

```python
def test_gradient_accumulation_loss_alignment_with_model_loss(self):
    set_seed(42)
    import datasets

    model_name = "nickypro/tinyllama-15M"
    dataset_name = "wikitext"
    dataset_config = "wikitext-2-raw-v1"
    dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:40]")
    dataset = dataset.train_test_split(test_size=0.2)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    def tokenize_function(examples):
        return tokenizer(examples["text"], max_length=16, padding="max_length", truncation=True)

    tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    model = AutoModelForCausalLM.from_pretrained(model_name)
    state_dict = model.state_dict()

    base_loss_callback = StoreLossCallback()

    args_kwargs = {
        "report_to": "none",
        "logging_steps": 1,
        "max_steps": 5,
        "learning_rate": 3e-4,
        "disable_tqdm": True,
    }

    with tempfile.TemporaryDirectory() as tmp_dir:
        args = TrainingArguments(
            tmp_dir,
            **args_kwargs,
        )
        trainer = Trainer(
            model,
            args,
            train_dataset=tokenized_dataset["train"],
            callbacks=[base_loss_callback],
            data_collator=data_collator,
        )
        assert trainer.model_accepts_loss_kwargs
        trainer.train()

    grad_accum_loss_callback = StoreLossCallback()
    with tempfile.TemporaryDirectory() as tmp_dir:
        args = TrainingArguments(
            tmp_dir,
            **args_kwargs,
            gradient_accumulation_steps=2,
            per_device_train_batch_size=4,
        )
        set_seed(42)
        model.load_state_dict(state_dict)
        trainer = Trainer(
            model,
            args,
            train_dataset=tokenized_dataset["train"],
            callbacks=[grad_accum_loss_callback],
            data_collator=data_collator,
        )
        trainer.train()

        set_seed(42)
        model.load_state_dict(state_dict)
        broken_loss_callback = StoreLossCallback()
        trainer = Trainer(
            model,
            args,
            train_dataset=tokenized_dataset["train"],
            callbacks=[broken_loss_callback],
            data_collator=data_collator,
        )
        # disable model_accepts_loss_kwargs
        trainer.model_accepts_loss_kwargs = False
        trainer.train()

    # Calculate the difference between the base loss and the grad_accum loss
    diff_truth = [
        abs(base - grad) for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)
    ]
    diff_broken = [
        abs(base - grad) for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)
    ]

    # all diff truth should be quite close
    self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01")

    # max diff broken should be very off
    self.assertGreater(max(diff_broken), 2, f"Difference {max(diff_broken)} is not greater than 2")
```
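The test relies on a `StoreLossCallback` helper to compare per-step losses across runs. For context, here is a minimal sketch of such a callback, assuming it only needs to record the `loss` value from each logging event; the actual helper in the transformers test suite may differ.

```python
from transformers import TrainerCallback


class StoreLossCallback(TrainerCallback):
    """Minimal sketch: collect the `loss` value from every logged metrics dict."""

    def __init__(self):
        self.losses = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        # Trainer calls on_log with the metrics dict; keep only the training loss.
        if logs is not None and "loss" in logs:
            self.losses.append(logs["loss"])
```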
Hi @muellerzr, I am not very familiar with your test cases, but we tried fine-tuning the Llama3 model and FYI the loss and gradients look good to me.
@muellerzr It might be better to also add test cases for models that do not accept `loss_kwargs`.
@hiyouga The model used in |
What does this PR do?
Fixes #35649
Who can review?
@muellerzr @hiyouga @ArthurZucker @SunMarc