test_attn_mask_position_ids_flash_attn_equality
garrett361 committed Jan 24, 2025
1 parent 8d191c7 commit d35bcc6
Showing 3 changed files with 125 additions and 42 deletions.
19 changes: 13 additions & 6 deletions src/transformers/models/bamba/modeling_bamba.py
@@ -872,6 +872,7 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.Tensor] = None,
**kwargs,
):
batch_size, seq_len, _ = hidden_states.shape
use_precomputed_states = (
@@ -944,17 +945,14 @@ def get_cu_seq_lens_from_position_ids(position_ids: torch.LongTensor) -> torch.L
torch.tensor(position_ids[0].shape, device=device),
),
)
return cu_seq_lens[None]
return cu_seq_lens


def get_seq_idx_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
batch_size = cu_seq_lens.shape[0]
if batch_size != 1:
raise ValueError("Only batch size 1 is supported.")
seq_idx = torch.cat(
[
torch.full((n,), idx, dtype=torch.int32, device=cu_seq_lens.device)
for idx, n in enumerate(torch.diff(cu_seq_lens[0], dim=-1))
for idx, n in enumerate(torch.diff(cu_seq_lens, dim=-1))
]
)
return seq_idx[None]
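
As an aside (not part of the diff): a minimal sketch of what these two helpers compute for a toy packed batch, assuming two sequences of lengths 3 and 2 flattened into a single row. Variable names and the toy tensors are illustrative only.

import torch

# Toy packed batch (illustrative): sequences of lengths 3 and 2 in one row.
position_ids = torch.tensor([[0, 1, 2, 0, 1]])

# Sequence starts are wherever position_ids resets to 0; together with the total
# length this gives the cumulative boundaries the helper returns (now 1-D after this change).
starts = (position_ids[0] == 0).nonzero().flatten()
cu_seq_lens = torch.cat([starts, torch.tensor([position_ids.shape[1]])]).to(torch.int32)
# cu_seq_lens == tensor([0, 3, 5], dtype=torch.int32)

# seq_idx labels every token with the index of the sequence it belongs to and is
# returned with a leading batch dimension, as in get_seq_idx_from_cu_seq_lens.
seq_idx = torch.cat(
    [torch.full((int(n),), i, dtype=torch.int32) for i, n in enumerate(cu_seq_lens.diff())]
)[None]
# seq_idx == tensor([[0, 0, 0, 1, 1]], dtype=torch.int32)
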
@@ -1028,7 +1026,7 @@ def forward(
seq_idx = get_seq_idx_from_cu_seq_lens(kwargs["cu_seq_lens_k"])
elif position_ids is not None:
cu_seq_lens = get_cu_seq_lens_from_position_ids(position_ids)
if len(cu_seq_lens[0]) == 2:
if len(cu_seq_lens) == 2:
# If cu_seq_lens only has two elements, then it is semantically equivalent to
# `seq_idx=None`, which is more efficient.
seq_idx = None
@@ -1244,6 +1242,15 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
if (
self.training
and (position_ids is not None or "cu_seq_lens_k" in flash_attn_kwargs)
and (self.config._attn_implementation != "flash_attention_2" or not is_fast_path_available)
):
raise ValueError(
"Padding-free training using position_ids or FlashAttentionKwargs requires "
"the flash_attention_2 attention implementation and mamba cuda and triton kernels."
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
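For context on the new guard above, here is a minimal sketch of the padding-free training call it is meant to allow. This is an assumption-laden illustration: the checkpoint id is hypothetical, and it presumes flash-attn and the mamba CUDA/Triton kernels are installed on a GPU machine.

import torch
from transformers import BambaForCausalLM

# Hypothetical checkpoint id, for illustration only; any Bamba checkpoint would do.
model = BambaForCausalLM.from_pretrained(
    "ibm-ai-platform/Bamba-9B",
    attn_implementation="flash_attention_2",  # required by the new check for padding-free training
    torch_dtype=torch.bfloat16,
).cuda().train()

# Two sequences of lengths 3 and 2 packed into one row, with per-sequence position_ids.
input_ids = torch.tensor([[5, 6, 7, 8, 9]], device="cuda")
position_ids = torch.tensor([[0, 1, 2, 0, 1]], device="cuda")

# With self.training == True and position_ids passed, the forward raises unless
# flash_attention_2 is the attention implementation and the mamba kernels are available.
loss = model(input_ids=input_ids, position_ids=position_ids, labels=input_ids).loss
loss.backward()
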
19 changes: 13 additions & 6 deletions src/transformers/models/bamba/modular_bamba.py
@@ -220,17 +220,14 @@ def get_cu_seq_lens_from_position_ids(position_ids: torch.LongTensor) -> torch.L
torch.tensor(position_ids[0].shape, device=device),
),
)
return cu_seq_lens[None]
return cu_seq_lens


def get_seq_idx_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
batch_size = cu_seq_lens.shape[0]
if batch_size != 1:
raise ValueError("Only batch size 1 is supported.")
seq_idx = torch.cat(
[
torch.full((n,), idx, dtype=torch.int32, device=cu_seq_lens.device)
for idx, n in enumerate(torch.diff(cu_seq_lens[0], dim=-1))
for idx, n in enumerate(torch.diff(cu_seq_lens, dim=-1))
]
)
return seq_idx[None]
@@ -678,6 +675,7 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.Tensor] = None,
**kwargs,
):
batch_size, seq_len, _ = hidden_states.shape
use_precomputed_states = (
@@ -776,7 +774,7 @@ def forward(
seq_idx = get_seq_idx_from_cu_seq_lens(kwargs["cu_seq_lens_k"])
elif position_ids is not None:
cu_seq_lens = get_cu_seq_lens_from_position_ids(position_ids)
if len(cu_seq_lens[0]) == 2:
if len(cu_seq_lens) == 2:
# If cu_seq_lens only has two elements, then it is semantically equivalent to
# `seq_idx=None`, which is more efficient.
seq_idx = None
@@ -992,6 +990,15 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
if (
self.training
and (position_ids is not None or "cu_seq_lens_k" in flash_attn_kwargs)
and (self.config._attn_implementation != "flash_attention_2" or not is_fast_path_available)
):
raise ValueError(
"Padding-free training using position_ids or FlashAttentionKwargs requires "
"the flash_attention_2 attention implementation and mamba cuda and triton kernels."
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
129 changes: 99 additions & 30 deletions tests/models/bamba/test_modeling_bamba.py
@@ -18,11 +18,15 @@
import unittest

import pytest
from pytest import mark

from transformers import AutoTokenizer, BambaConfig, is_torch_available
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.models.bamba.modular_bamba import get_cu_seq_lens_from_position_ids, get_seq_idx_from_cu_seq_lens
from transformers.testing_utils import (
require_flash_attn,
require_torch,
require_torch_gpu,
slow,
torch_device,
)
@@ -482,6 +486,90 @@ def _prepare_model_kwargs(input_ids, attention_mask, signature):
# They should result in very similar logits
torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5)

@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
def test_attn_mask_position_ids_flash_attn_equality(self):
r"""
Verify that the logits agree when using an attention mask, position_ids, or
FlashAttentionKwargs.
"""
torch.manual_seed(42)
decoder_only_classes = []
for model_class in self.all_generative_model_classes:
config, _, _, _ = self.model_tester.prepare_config_and_inputs()
if config.is_encoder_decoder:
continue
else:
decoder_only_classes.append(model_class)
if len(decoder_only_classes) == 0:
self.skipTest(reason="No decoder-only architecture available for this model.")

# - Decoder-only architectures derived from encoder-decoder models could support it in theory, but we haven't
# added support for it yet. We skip these models for now.
has_encoder_attributes = any(
attr_name
for attr_name in config.to_dict().keys()
if attr_name.startswith("encoder") and attr_name != "encoder_no_repeat_ngram_size"
)
if has_encoder_attributes:
self.skipTest(
reason="The decoder-only derived from encoder-decoder models are not expected to support left-padding."
)

for model_class in decoder_only_classes:
config, input_ids, input_mask, _ = self.model_tester.prepare_config_and_inputs()
# Padding-free requires training = True and attn_implementation="flash_attention_2"
model = (
model_class._from_config(config, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16)
.to(torch_device)
.train()
)

non_padding_free_inputs = {"input_ids": input_ids, "attention_mask": input_mask}
attn_mask_logits = model(**non_padding_free_inputs).logits

# Build up padding-free tensors
padding_free_input_ids = torch.cat(
[batch[mask.bool()] for batch, mask in zip(input_ids, input_mask)], dim=-1
)[None]
position_ids_list = [
torch.arange(mask.sum(), device=mask.device, dtype=torch.int32) for mask in input_mask
]
position_ids = torch.cat(position_ids_list, dim=-1)[None]
seq_lens = torch.cat(
[torch.tensor([t.numel()], device=input_mask.device, dtype=torch.int32) for t in position_ids_list],
dim=-1,
)
cu_seq_lens = torch.cat(
[
torch.tensor([0], device=input_mask.device, dtype=torch.int32),
seq_lens.cumsum(dim=-1, dtype=torch.int32),
],
dim=-1,
)

position_ids_inputs = {"input_ids": padding_free_input_ids, "position_ids": position_ids}
position_ids_logits = model(**position_ids_inputs).logits

flash_attn_kwargs = FlashAttentionKwargs(
cu_seq_lens_q=cu_seq_lens,
cu_seq_lens_k=cu_seq_lens,
max_length_q=input_ids.shape[-1],
max_length_k=input_ids.shape[-1],
)
flash_attn_kwargs_logits = model(input_ids=padding_free_input_ids, **flash_attn_kwargs).logits

attn_mask_logits_reshaped = torch.cat(
[batch[mask.bool()] for batch, mask in zip(attn_mask_logits, input_mask)], dim=0
)[None]

torch.testing.assert_close(position_ids_logits, attn_mask_logits_reshaped)
# A higher tolerance is needed for the position_ids and FlashAttentionKwargs logits to
# match, for unknown reasons.
torch.testing.assert_close(position_ids_logits, flash_attn_kwargs_logits, atol=1e-3, rtol=1e-1)
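
A compact restatement of the padding-free input construction exercised by the test above, written as a standalone helper sketch. The function name and exact choices are illustrative, not part of the library or the test; it assumes each row's valid tokens are contiguous in the attention mask.

import torch
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs

def build_padding_free_inputs(input_ids: torch.Tensor, attention_mask: torch.Tensor):
    """Flatten a padded (batch, seq) pair into packed, padding-free tensors (illustrative)."""
    seq_lens = attention_mask.sum(dim=-1).to(torch.int32)  # valid tokens per row
    packed_ids = input_ids[attention_mask.bool()][None]    # (1, total_tokens)
    position_ids = torch.cat(
        [torch.arange(int(n), dtype=torch.int32, device=input_ids.device) for n in seq_lens]
    )[None]
    cu_seq_lens = torch.cat(
        [torch.zeros(1, dtype=torch.int32, device=input_ids.device),
         seq_lens.cumsum(dim=-1, dtype=torch.int32)]
    )
    flash_kwargs = FlashAttentionKwargs(
        cu_seq_lens_q=cu_seq_lens,
        cu_seq_lens_k=cu_seq_lens,
        max_length_q=int(seq_lens.max()),
        max_length_k=int(seq_lens.max()),
    )
    return packed_ids, position_ids, flash_kwargs

Calling model(input_ids=packed_ids, position_ids=position_ids, **flash_kwargs) should then reproduce the logits of the masked batched call, up to the tolerances checked in the test.
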


@slow
@require_torch
@@ -598,60 +686,41 @@ def test_simple_batched_generate_with_padding(self):
def test_cu_seq_lens_from_position_ids() -> None:
seq_length = 256
chunks_per_batch = 4
batch_size = 1

# Split each batch into `chunks_per_batch` sequences.
eos_idxs = (
torch.stack([torch.randperm(seq_length) for _ in range(batch_size)], dim=0)[:, : chunks_per_batch - 1]
.sort(dim=-1)
.values
)
seq_lens = torch.cat(
(torch.full((batch_size, 1), -1), eos_idxs, torch.full((batch_size, 1), seq_length - 1)), dim=-1
).diff(dim=-1)
eos_idxs = torch.randperm(seq_length)[: chunks_per_batch - 1].sort(dim=-1).values
seq_lens = torch.cat((torch.full((1,), -1), eos_idxs, torch.full((1,), seq_length - 1)), dim=-1).diff(dim=-1)

# Create the corresponding position_ids and seq_idx
position_ids = torch.stack(
[
torch.cat(
[torch.arange(s, dtype=torch.int32) for s in sl],
dim=0,
)
for sl in seq_lens
],
position_ids = torch.cat(
[torch.arange(s, dtype=torch.int32) for s in seq_lens],
dim=0,
)
)[None]

cu_seq_lens_pred = get_cu_seq_lens_from_position_ids(position_ids)
assert torch.allclose(
cu_seq_lens_pred,
torch.cat(
[torch.tensor([[0]], dtype=seq_lens.dtype, device=seq_lens.device), seq_lens.cumsum(dim=-1)], dim=-1
),
torch.cat([torch.tensor([0], dtype=seq_lens.dtype, device=seq_lens.device), seq_lens.cumsum(dim=-1)], dim=-1),
)


def test_seq_idx_from_cu_seq_lens() -> None:
n_chunks = 5
max_chunk_len = 64
batch_size = 1

seq_lens = torch.randint(1, max_chunk_len, size=(batch_size, n_chunks))
cu_seq_lens = torch.cat([torch.tensor([[0]]), seq_lens.cumsum(dim=-1)], dim=-1)
seq_lens = torch.randint(1, max_chunk_len, size=(n_chunks,))
cu_seq_lens = torch.cat([torch.tensor([0]), seq_lens.cumsum(dim=-1)], dim=-1)
seq_idx = torch.cat(
[
torch.full(
(
batch_size,
n,
),
(n,),
idx,
dtype=torch.int32,
device=cu_seq_lens.device,
)
for idx, n in enumerate(seq_lens[0])
for idx, n in enumerate(seq_lens)
],
dim=-1,
)
)[None]
seq_idx_pred = get_seq_idx_from_cu_seq_lens(cu_seq_lens)
assert torch.allclose(seq_idx_pred, seq_idx)
