feat: Add gradient testing for Flash Attention 2 #35780

Open · wants to merge 2 commits into main
Conversation


@crStiv commented Jan 20, 2025

This commit adds gradient testing for Flash Attention 2 (FA2) implementation.
The test ensures that gradients computed with FA2 match those computed with
the eager implementation within specified tolerance thresholds.

Key changes:

  • Added gradient testing in train mode
  • Compared gradients between the eager and FA2 implementations
  • Used the same tolerance thresholds as the forward-pass comparison
  • Properly cleaned up gradients and restored eval mode

This addresses the TODO comment in test_modeling_common.py and improves
test coverage for FA2 implementation.
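
For context, a minimal sketch of what such a gradient check might look like. The model id, tolerances, and helper name below are illustrative, not taken from the PR:

```python
import torch
from transformers import AutoModelForCausalLM


def check_fa2_gradients(model_id, atol=4e-2, rtol=4e-2):
    """Compare backward-pass gradients between eager and FA2 attention."""
    input_ids = torch.randint(0, 100, (1, 16), device="cuda")
    grads = {}
    for impl in ("eager", "flash_attention_2"):
        # FA2 requires fp16/bf16, so load both variants in bfloat16 for a fair comparison.
        model = AutoModelForCausalLM.from_pretrained(
            model_id, attn_implementation=impl, torch_dtype=torch.bfloat16
        ).to("cuda")
        model.train()  # gradient testing runs in train mode
        loss = model(input_ids, labels=input_ids).loss
        loss.backward()
        grads[impl] = {
            name: p.grad.clone()
            for name, p in model.named_parameters()
            if p.grad is not None
        }
        model.zero_grad(set_to_none=True)  # properly clean up gradients
        model.eval()  # restore eval mode

    # Gradients from the two backends should agree within the tolerance thresholds.
    for name, eager_grad in grads["eager"].items():
        torch.testing.assert_close(
            grads["flash_attention_2"][name], eager_grad, atol=atol, rtol=rtol
        )
```

In the PR itself the check lives in the common model tests and reuses the existing forward-pass tolerance logic; the sketch above only illustrates the core backward comparison.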

@ArthurZucker (Collaborator) left a comment


Nice! I think the test can be simplified: it no longer needs two attention models, as you can just set config._attn_implementation for newer models!
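
A rough sketch of that simplification, continuing the example above and assuming a model whose attention dispatch reads config._attn_implementation at forward time (as newer transformers models do):

```python
# Instead of instantiating two models, load one and switch the attention
# backend via the config between runs (hypothetical sketch).
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16
).to("cuda")

for impl in ("eager", "flash_attention_2"):
    model.config._attn_implementation = impl  # dispatch is read from the config on newer models
    model.train()
    model(input_ids, labels=input_ids).loss.backward()
    # ... collect and compare gradients as above, then clean up ...
    model.zero_grad(set_to_none=True)
    model.eval()
```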
