
PR #35438 introduced a new bug #35649

Open · 2 of 4 tasks
techkang opened this issue Jan 13, 2025 · 7 comments · May be fixed by #35651

@techkang (Contributor)

System Info

(base) MBP-HD6JD9Q599-2052 :: ~/code/transformers ‹main*› % transformers-cli env

Copy-and-paste the text below in your GitHub issue and FILL OUT the two last points.

  • transformers version: 4.49.0.dev0
  • Platform: macOS-14.6.1-arm64-arm-64bit
  • Python version: 3.12.4
  • Huggingface_hub version: 0.27.1
  • Safetensors version: 0.4.3
  • Accelerate version: 1.2.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.5.1 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:

Who can help?

@muellerzr @hiyouga @ArthurZucker @SunMarc

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

export RUN_SLOW=True
pytest tests/trainer/test_trainer.py::TrainerIntegrationPrerunTest::test_gradient_accumulation_loss_alignment_with_loss_func

======================================================= short test summary info =======================================================
FAILED tests/trainer/test_trainer.py::TrainerIntegrationPrerunTest::test_gradient_accumulation_loss_alignment_with_loss_func - AssertionError: 3.0949999999999998 not less than 0.01 : Difference 3.0949999999999998 is not within 0.01
=================================================== 1 failed, 2 warnings in 54.91s ====================================================

Expected behavior

The test passes.

techkang added the bug label on Jan 13, 2025
@techkang (Contributor, Author)

PR link: #35438

techkang changed the title from "PR https://github.com/huggingface/transformers/pull/35438 introduced a new bug" to "PR #35438 introduced a new bug" on Jan 13, 2025
@techkang (Contributor, Author) commented Jan 13, 2025

I think PR #35438 should be reverted; the proper way to fix the bug mentioned in that PR is as follows.

In the following code, the loss is scaled when num_items_in_batch is None:
https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L3712-L3713

But the true intent of this code is to scale the loss only when the gradient-accumulation (GA) fix is not applied, and after recent PRs that is no longer equivalent to num_items_in_batch being None. So the condition should be changed to

if not self.model_accepts_loss_kwargs and self.compute_loss_func is None
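
A minimal, standalone sketch (not the actual trainer.py code) that makes the two conditions explicit. The names model_accepts_loss_kwargs, compute_loss_func, and num_items_in_batch mirror the Trainer attributes referenced above; grad_accum_steps and the plain division stand in for the real scaling logic:

def scale_loss_current(loss, num_items_in_batch, grad_accum_steps):
    # Condition currently on main: scale whenever num_items_in_batch is None.
    if num_items_in_batch is None:
        loss = loss / grad_accum_steps
    return loss


def scale_loss_proposed(loss, model_accepts_loss_kwargs, compute_loss_func, grad_accum_steps):
    # Proposed condition: scale only when the GA fix is not in effect, i.e. the
    # model does not accept loss kwargs and no custom loss function is set.
    if not model_accepts_loss_kwargs and compute_loss_func is None:
        loss = loss / grad_accum_steps
    return loss


# The two predicates disagree whenever the None-ness of num_items_in_batch does
# not track the (model_accepts_loss_kwargs, compute_loss_func) pair, e.g. a
# custom loss function is set but num_items_in_batch comes through as None:
print(scale_loss_current(4.0, None, 4))                             # 1.0
print(scale_loss_proposed(4.0, False, lambda out, labels: None, 4)) # 4.0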

@muellerzr (Contributor)

Thanks! Would you like to make a PR for this? Otherwise I can do so today.

@techkang (Contributor, Author)

@muellerzr Thanks for the reply. I will open a PR today.

techkang linked pull request #35651 on Jan 13, 2025 that will close this issue
@hiyouga (Contributor) commented Jan 13, 2025

Hi @techkang, our rigorous experiments in #35438 showed evidence that #35121 introduced a bug that makes the loss of the Qwen2VL model incorrect. I think we should not only focus on models with a loss function but also pay attention to models without loss_kwargs. There should be a solution that lets both cases work instead of simply reverting our fix. cc @muellerzr
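
To make the two model families in this comment concrete, here is a small, hedged enumeration (again not Trainer code, just the predicate from the proposed condition) of which combinations of "model accepts loss kwargs" and "user-supplied compute_loss_func" would still receive the extra gradient-accumulation division; only the combination with neither mechanism gets scaled:

from itertools import product

# Enumerate the four combinations and show which ones the proposed condition
# would still divide by the gradient-accumulation steps.
for accepts_loss_kwargs, has_loss_func in product([True, False], repeat=2):
    scales = not accepts_loss_kwargs and not has_loss_func
    print(f"accepts_loss_kwargs={accepts_loss_kwargs!s:<5} "
          f"compute_loss_func set={has_loss_func!s:<5} -> scaled by GA steps: {scales}")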

@techkang (Contributor, Author)

Hi @hiyouga, #35121 indeed introduced a bug, but I don't think #35438 is the proper way to fix it. Could you try to verify the Qwen2VL loss with the new PR, #35651?

@hiyouga (Contributor) commented Jan 13, 2025

@techkang Yep, the new PR looks better to me; let's run some experiments on it.
