feat: Add gradient testing for Flash Attention 2 #35780

Open · wants to merge 2 commits into base: main · Changes from 1 commit
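
The diff below extends `test_eager_matches_sdpa_inference` in `tests/test_modeling_common.py`: after the existing forward-pass comparison, both copies of the model are switched to train mode, run again under the same SDPA kernel selection, and a scalar loss (the mean of the last hidden states) is backpropagated so that the gradients of corresponding parameters can be compared with `torch.allclose` under the same tolerances. A minimal, standalone sketch of that gradient-parity pattern follows; it uses a hypothetical `ToyAttention` module with an explicit eager path and an SDPA path instead of the Transformers test harness, so the names, shapes, and tolerances here are illustrative only.

import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyAttention(nn.Module):
    """Single-head self-attention; `use_sdpa` toggles between an explicit
    (eager) attention computation and F.scaled_dot_product_attention."""

    def __init__(self, dim, use_sdpa):
        super().__init__()
        self.qkv = nn.Linear(dim, 3 * dim)
        self.out = nn.Linear(dim, dim)
        self.use_sdpa = use_sdpa

    def forward(self, x):
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        if self.use_sdpa:
            attn = F.scaled_dot_product_attention(q, k, v)
        else:
            # Eager reference: softmax(QK^T / sqrt(d)) V
            scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
            attn = scores.softmax(dim=-1) @ v
        return self.out(attn)


def check_gradient_parity(atol=1e-5, rtol=1e-4):
    torch.manual_seed(0)
    eager = ToyAttention(dim=16, use_sdpa=False)
    sdpa = copy.deepcopy(eager)  # identical weights, different attention path
    sdpa.use_sdpa = True

    eager.train()
    sdpa.train()

    x = torch.randn(2, 8, 16)  # (batch, seq_len, hidden)

    # Same scalar loss on both paths, mirroring the PR's logits.mean()
    eager(x).mean().backward()
    sdpa(x).mean().backward()

    # Compare gradients parameter by parameter
    for p_eager, p_sdpa in zip(eager.parameters(), sdpa.parameters()):
        assert torch.allclose(p_eager.grad, p_sdpa.grad, atol=atol, rtol=rtol)

    eager.zero_grad()
    sdpa.zero_grad()


check_gradient_parity()
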
tests/test_modeling_common.py: 98 changes (55 additions, 43 deletions)
@@ -4175,52 +4175,64 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
        outputs_eager = model_eager(**prepared_inputs)
        outputs_sdpa = model_sdpa(**prepared_inputs)

if hasattr(outputs_eager, "vision_hidden_states"):
    logits_eager = outputs_eager.vision_hidden_states[-1]
    logits_sdpa = outputs_sdpa.vision_hidden_states[-1]
else:
    logits_eager = (
        outputs_eager.hidden_states[-1]
        if not is_encoder_decoder
        else outputs_eager.decoder_hidden_states[-1]
    )
    logits_sdpa = (
        outputs_sdpa.hidden_states[-1]
        if not is_encoder_decoder
        else outputs_sdpa.decoder_hidden_states[-1]
    )

if torch_device in ["cpu", "cuda"]:
    atol = atols[torch_device, enable_kernels, torch_dtype]
    rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
    # As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
    # which is implemented on PyTorch level using aten operators and is
    # device agnostic with respect to implementation of each aten operator.
    atol = atols["cuda", False, torch_dtype]
    rtol = rtols["cuda", False, torch_dtype]
else:
    atol = 1e-7
    rtol = 1e-4

# Masked tokens output slightly deviates - we don't mind that.
if use_mask:
    _logits_sdpa = torch.zeros_like(input=logits_sdpa)
    _logits_eager = torch.zeros_like(input=logits_eager)

    _logits_sdpa[:-1] = logits_sdpa[:-1]
    _logits_eager[:-1] = logits_eager[:-1]
# Test gradients
model_eager.train()
model_sdpa.train()

with sdpa_kernel(
    enable_flash=enable_kernels,
    enable_math=True,
    enable_mem_efficient=enable_kernels,
):
    prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
    outputs_eager = model_eager(**prepared_inputs)
    outputs_sdpa = model_sdpa(**prepared_inputs)

    if padding_side == "left":
        _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:]
        _logits_eager[-1:, 2:] = logits_eager[-1:, 2:]
if hasattr(outputs_eager, "vision_hidden_states"):
    logits_eager = outputs_eager.vision_hidden_states[-1]
    logits_sdpa = outputs_sdpa.vision_hidden_states[-1]
else:
    logits_eager = (
        outputs_eager.hidden_states[-1]
        if not is_encoder_decoder
        else outputs_eager.decoder_hidden_states[-1]
    )
    logits_sdpa = (
        outputs_sdpa.hidden_states[-1]
        if not is_encoder_decoder
        else outputs_sdpa.decoder_hidden_states[-1]
    )

    elif padding_side == "right":
        _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2]
        _logits_eager[-1:, 2:] = logits_eager[-1:, :-2]
# Compute gradients
loss_eager = logits_eager.mean()
loss_sdpa = logits_sdpa.mean()

loss_eager.backward()
loss_sdpa.backward()

# Compare gradients
for p_eager, p_sdpa in zip(model_eager.parameters(), model_sdpa.parameters()):
    if p_eager.grad is not None and p_sdpa.grad is not None:
        self.assertTrue(
            torch.allclose(
                p_eager.grad, p_sdpa.grad,
                atol=atols[torch_device, enable_kernels, torch_dtype],
                rtol=rtols[torch_device, enable_kernels, torch_dtype]
            ),
            f"Gradients do not match for parameter {p_eager}"
        )

# Reset gradients
model_eager.zero_grad()
model_sdpa.zero_grad()

# Set models back to eval mode
model_eager.eval()
model_sdpa.eval()

    logits_sdpa = _logits_sdpa
    logits_eager = _logits_eager
if hasattr(outputs_eager, "vision_hidden_states"):
    logits_eager = outputs_eager.vision_hidden_states[-1]
    logits_sdpa = outputs_sdpa.vision_hidden_states[-1]

results = [
    torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol)
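
The forward passes above run inside `sdpa_kernel(enable_flash=..., enable_math=..., enable_mem_efficient=...)`, a context manager used by these tests to restrict which scaled-dot-product-attention backends PyTorch may dispatch to (flash, memory-efficient, or the math reference). For reference, a minimal sketch of the same backend selection with the native `torch.nn.attention` API is given below; it assumes PyTorch 2.3 or newer and is not the helper used in this test file.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 4, 8, 16)  # (batch, heads, seq_len, head_dim)

# Math backend only: the device-agnostic reference implementation
with sdpa_kernel([SDPBackend.MATH]):
    out_math = F.scaled_dot_product_attention(q, k, v)

# Allow fused kernels where available, keeping MATH as a fallback
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
    out_fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(out_math, out_fused, atol=1e-5, rtol=1e-4))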