
Fix condition when GA loss bug fix is not performed #35651

Open · wants to merge 4 commits into main
Conversation

techkang (Contributor)

What does this PR do?

Fixes #35649

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@muellerzr @hiyouga @ArthurZucker @SunMarc

```diff
@@ -3709,7 +3706,7 @@ def training_step(
                 scaled_loss.backward()
             else:
                 # Finally we need to normalize the loss for reporting
-                if num_items_in_batch is None:
+                if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
```
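For context, the fixed condition can be sketched as a standalone helper: the loss is divided by the accumulation steps for reporting only when neither the model nor a custom loss function already handles normalization. Names mirror the diff (`model_accepts_loss_kwargs`, `compute_loss_func`); the surrounding Trainer machinery is assumed away, so treat this as an illustration rather than the actual implementation.

```python
def normalize_loss_for_reporting(loss, gradient_accumulation_steps,
                                 model_accepts_loss_kwargs=False,
                                 compute_loss_func=None):
    """Divide the loss by the accumulation steps only when neither the
    model's forward nor a custom loss function normalizes it already."""
    if not model_accepts_loss_kwargs and compute_loss_func is None:
        loss = loss / gradient_accumulation_steps
    return loss


# With the broken condition (`num_items_in_batch is None`), models that
# accept loss kwargs could be normalized twice; the fixed condition
# skips normalization for them:
print(normalize_loss_for_reporting(8.0, 4))                                # 2.0
print(normalize_loss_for_reporting(8.0, 4, model_accepts_loss_kwargs=True))  # 8.0
```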
Looks good to me

@muellerzr left a comment
@hiyouga let us know if your tests go okay! If they do I'll give this a ✔️

```diff
@@ -855,7 +855,7 @@ def tokenize_function(examples):
         self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01")

         # max diff broken should be very off
-        self.assertGreater(max(diff_broken), 3, f"Difference {max(diff_broken)} is not greater than 3")
+        self.assertGreater(max(diff_broken), 2, f"Difference {max(diff_broken)} is not greater than 3")
```
Suggested change:

```diff
-self.assertGreater(max(diff_broken), 2, f"Difference {max(diff_broken)} is not greater than 3")
+self.assertGreater(max(diff_broken), 2, f"Difference {max(diff_broken)} is not greater than 2")
```

@techkang (author)
Thanks, fixed.

@hiyouga commented Jan 13, 2025

@muellerzr The loss values and gradients align with the latest main branch after applying this fix. We tested with the Qwen2-VL-7B model and gradient_accumulation_steps = 8.

[screenshot of test results attached]

@techkang (author)
I added an extra validation that guarantees the fixed loss and the unfixed loss do not vary too much.

@muellerzr left a comment

Thanks so much for identifying, checking, and the wonderful collaboration all around 🤗 Arguably one of the hardest tricks we've had to deal with in the Trainer in quite some time, so we really appreciate everyone's effort in making sure we get it fully right for everyone.

@ArthurZucker we're good to go 🚀

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@muellerzr commented Jan 13, 2025

@techkang @hiyouga tentatively something seems off here. One of my two tests fails: we succeed with test_gradient_accumulation_loss_alignment_with_loss_func, but not with test_gradient_accumulation_loss_alignment_with_model_loss.

My modified version of the test I'm working on:

    def test_gradient_accumulation_loss_alignment_with_model_loss(self):
        set_seed(42)
        import datasets

        model_name = "nickypro/tinyllama-15M"
        dataset_name = "wikitext"
        dataset_config = "wikitext-2-raw-v1"
        dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:40]")
        dataset = dataset.train_test_split(test_size=0.2)
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        tokenizer.pad_token = tokenizer.eos_token

        def tokenize_function(examples):
            return tokenizer(examples["text"], max_length=16, padding="max_length", truncation=True)

        tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)

        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

        model = AutoModelForCausalLM.from_pretrained(model_name)
        state_dict = model.state_dict()

        base_loss_callback = StoreLossCallback()

        args_kwargs = {
            "report_to": "none",
            "logging_steps": 1,
            "max_steps": 5,
            "learning_rate": 3e-4,
            "disable_tqdm": True,
        }

        with tempfile.TemporaryDirectory() as tmp_dir:
            args = TrainingArguments(
                tmp_dir,
                **args_kwargs,
            )
            trainer = Trainer(
                model,
                args,
                train_dataset=tokenized_dataset["train"],
                callbacks=[base_loss_callback],
                data_collator=data_collator,
            )
            assert trainer.model_accepts_loss_kwargs
            trainer.train()

        grad_accum_loss_callback = StoreLossCallback()
        with tempfile.TemporaryDirectory() as tmp_dir:
            args = TrainingArguments(
                tmp_dir,
                **args_kwargs,
                gradient_accumulation_steps=2,
                per_device_train_batch_size=4,
            )
            set_seed(42)
            model.load_state_dict(state_dict)
            trainer = Trainer(
                model,
                args,
                train_dataset=tokenized_dataset["train"],
                callbacks=[grad_accum_loss_callback],
                data_collator=data_collator,
            )
            trainer.train()

            set_seed(42)
            model.load_state_dict(state_dict)
            broken_loss_callback = StoreLossCallback()
            trainer = Trainer(
                model,
                args,
                train_dataset=tokenized_dataset["train"],
                callbacks=[broken_loss_callback],
                data_collator=data_collator,
            )
            # disable model_accepts_loss_kwargs
            trainer.model_accepts_loss_kwargs = False
            trainer.train()

            # Calculate the difference between the base loss and the grad_accum loss
            diff_truth = [
                abs(base - grad) for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)
            ]
            diff_broken = [
                abs(base - grad) for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)
            ]

            # all diff truth should be quite close
            self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01")

            # max diff broken should be very off
            self.assertGreater(max(diff_broken), 2, f"Difference {max(diff_broken)} is not greater than 2")
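The test above relies on a `StoreLossCallback` helper that is defined elsewhere in the transformers test suite. A minimal duck-typed sketch of what such a callback could look like (the real class subclasses `TrainerCallback` and its exact shape is an assumption here): it records the `loss` value from every logging event so the test can compare loss curves across runs.

```python
class StoreLossCallback:
    """Hypothetical sketch: collect the training loss from each log event.

    Trainer invokes on_log with the current logs dict at every logging
    step; we keep only entries that carry a "loss" key.
    """

    def __init__(self):
        self.losses = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None and "loss" in logs:
            self.losses.append(logs["loss"])
```

With this shape, `base_loss_callback.losses` and friends are plain lists of floats, which is all the `diff_truth` / `diff_broken` comparisons in the test need.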

@hiyouga commented Jan 13, 2025

Hi @muellerzr, I am not very familiar with your test cases, but we tried fine-tuning the Llama3 model and, FYI, the loss and gradients look good to me.

[screenshot of test results attached]

@muellerzr commented Jan 13, 2025

@techkang @hiyouga yeah the issue was the load_state_dict. For some reason one chunk of it is sensitive.

Post fixing my own tests, we're fully green on both ✅

(I'll make a PR with these new tests after!)
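The `load_state_dict` sensitivity mentioned above likely comes from aliasing: in PyTorch, `model.state_dict()` returns references to the live parameter tensors, so a later `trainer.train()` mutates the "saved" weights too, and restoring from them no longer resets the model. A stdlib-only sketch of the same trap, with plain dicts standing in for parameter tensors (in the torch case the fix would be cloning each tensor, e.g. `{k: v.clone() for k, v in model.state_dict().items()}`):

```python
import copy

params = {"weight": [1.0, 2.0]}      # stands in for live model parameters
saved = params                        # aliases the live values, like state_dict()
snapshot = copy.deepcopy(params)      # a true, independent snapshot

params["weight"][0] += 1.0            # a "training" update mutates the model

# The aliased "saved" copy moved together with the model...
print(saved["weight"])                # [2.0, 2.0]
# ...while the deep-copied snapshot still holds the original weights.
print(snapshot["weight"])             # [1.0, 2.0]
```

Restoring from `saved` would silently reload the already-trained weights, which would make the second and third training runs in the test start from different states.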

@hiyouga commented Jan 13, 2025

@muellerzr It might be better to also add test cases for models that do not accept loss_kwargs, WDYT?

@muellerzr

@hiyouga sure if you'd like to build on what I have started with #35668 :)

@techkang (author)

@hiyouga The model used in test_gradient_accumulation_loss_alignment_with_loss_func doesn't accept loss_kwargs.
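Whether a model "accepts loss_kwargs" can be probed from its forward signature. A hedged sketch of one way a trainer might make that decision, by checking for a `**kwargs`-style parameter (the actual transformers check may differ; the forward functions below are hypothetical stand-ins):

```python
import inspect

def accepts_loss_kwargs(forward_fn):
    """Return True if forward_fn takes a **kwargs parameter, through
    which extra loss-related arguments (e.g. num_items_in_batch) could
    be forwarded to the loss computation."""
    params = inspect.signature(forward_fn).parameters.values()
    return any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params)

# Hypothetical forward signatures for illustration:
def forward_with_kwargs(input_ids, labels=None, **loss_kwargs): ...
def forward_plain(input_ids, labels=None): ...

print(accepts_loss_kwargs(forward_with_kwargs))  # True
print(accepts_loss_kwargs(forward_plain))        # False
```

Under this reading, a model like the one in test_gradient_accumulation_loss_alignment_with_loss_func, whose forward has no `**kwargs`, would naturally exercise the `not model_accepts_loss_kwargs` branch of the fixed condition.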

Successfully merging this pull request may close these issues.

PR #35438 introduced a new bug