diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py
index ba68cc1f2de1..8617a0c342fd 100644
--- a/src/transformers/models/bamba/modeling_bamba.py
+++ b/src/transformers/models/bamba/modeling_bamba.py
@@ -872,6 +872,7 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         seq_idx: Optional[torch.Tensor] = None,
+        **kwargs,
     ):
         batch_size, seq_len, _ = hidden_states.shape
         use_precomputed_states = (
@@ -944,17 +945,14 @@ def get_cu_seq_lens_from_position_ids(position_ids: torch.LongTensor) -> torch.L
             torch.tensor(position_ids[0].shape, device=device),
         ),
     )
-    return cu_seq_lens[None]
+    return cu_seq_lens
 
 
 def get_seq_idx_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
-    batch_size = cu_seq_lens.shape[0]
-    if batch_size != 1:
-        raise ValueError("Only batch size 1 is supported.")
     seq_idx = torch.cat(
         [
             torch.full((n,), idx, dtype=torch.int32, device=cu_seq_lens.device)
-            for idx, n in enumerate(torch.diff(cu_seq_lens[0], dim=-1))
+            for idx, n in enumerate(torch.diff(cu_seq_lens, dim=-1))
         ]
     )
     return seq_idx[None]
@@ -1028,7 +1026,7 @@ def forward(
             seq_idx = get_seq_idx_from_cu_seq_lens(kwargs["cu_seq_lens_k"])
         elif position_ids is not None:
             cu_seq_lens = get_cu_seq_lens_from_position_ids(position_ids)
-            if len(cu_seq_lens[0]) == 2:
+            if len(cu_seq_lens) == 2:
                 # If cu_seq_lens only has two elements, then it is semantically equivalent to
                 # `seq_idx=None`, which is more efficient.
                 seq_idx = None
@@ -1244,6 +1242,15 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[Tuple, BaseModelOutputWithPast]:
+        if (
+            self.training
+            and (position_ids is not None or "cu_seq_lens_k" in flash_attn_kwargs)
+            and (self.config._attn_implementation != "flash_attention_2" or not is_fast_path_available)
+        ):
+            raise ValueError(
+                "Padding-free training using position_ids or FlashAttentionKwargs requires "
+                "the flash_attention_2 attention implementation and the mamba CUDA and Triton kernels."
+            )
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
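Reviewer note: the two helpers above now take `(1, total_len)` position_ids and return a 1D `cu_seq_lens` (the `[None]` unsqueeze is gone), while `seq_idx` keeps its `(1, total_len)` shape. A minimal, self-contained sketch of that contract, using illustrative reimplementations rather than the library code, assuming batch size 1 and position_ids that reset to zero at each packed sequence:

```python
import torch


def cu_seq_lens_sketch(position_ids: torch.LongTensor) -> torch.LongTensor:
    # Stand-in for get_cu_seq_lens_from_position_ids: packed sequences are
    # delimited wherever the position values reset to zero.
    pos = position_ids[0]
    starts = torch.nonzero(pos == 0).flatten()
    total = torch.tensor([pos.numel()])
    return torch.cat((starts, total))  # 1D after this PR (no [None] unsqueeze)


def seq_idx_sketch(cu_seq_lens: torch.Tensor) -> torch.Tensor:
    # Stand-in for get_seq_idx_from_cu_seq_lens: label each token with the
    # index of the packed sequence it belongs to.
    return torch.cat(
        [
            torch.full((n,), idx, dtype=torch.int32)
            for idx, n in enumerate(torch.diff(cu_seq_lens, dim=-1))
        ]
    )[None]


# Two sequences of lengths 3 and 5 packed into one batch row.
position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3, 4]])
cu_seq_lens = cu_seq_lens_sketch(position_ids)
print(cu_seq_lens)                  # tensor([0, 3, 8])
print(seq_idx_sketch(cu_seq_lens))  # tensor([[0, 0, 0, 1, 1, 1, 1, 1]], dtype=torch.int32)
```

Keeping `cu_seq_lens` 1D matches the layout the flash-attn varlen kernels expect, which appears to be why the batch-size guard in `get_seq_idx_from_cu_seq_lens` could be dropped.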
diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py
index 82a0d174ed80..119b9afa0b49 100644
--- a/src/transformers/models/bamba/modular_bamba.py
+++ b/src/transformers/models/bamba/modular_bamba.py
@@ -220,17 +220,14 @@ def get_cu_seq_lens_from_position_ids(position_ids: torch.LongTensor) -> torch.L
             torch.tensor(position_ids[0].shape, device=device),
         ),
     )
-    return cu_seq_lens[None]
+    return cu_seq_lens
 
 
 def get_seq_idx_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
-    batch_size = cu_seq_lens.shape[0]
-    if batch_size != 1:
-        raise ValueError("Only batch size 1 is supported.")
     seq_idx = torch.cat(
         [
             torch.full((n,), idx, dtype=torch.int32, device=cu_seq_lens.device)
-            for idx, n in enumerate(torch.diff(cu_seq_lens[0], dim=-1))
+            for idx, n in enumerate(torch.diff(cu_seq_lens, dim=-1))
         ]
     )
     return seq_idx[None]
@@ -678,6 +675,7 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         seq_idx: Optional[torch.Tensor] = None,
+        **kwargs,
     ):
         batch_size, seq_len, _ = hidden_states.shape
         use_precomputed_states = (
@@ -776,7 +774,7 @@ def forward(
             seq_idx = get_seq_idx_from_cu_seq_lens(kwargs["cu_seq_lens_k"])
         elif position_ids is not None:
             cu_seq_lens = get_cu_seq_lens_from_position_ids(position_ids)
-            if len(cu_seq_lens[0]) == 2:
+            if len(cu_seq_lens) == 2:
                 # If cu_seq_lens only has two elements, then it is semantically equivalent to
                 # `seq_idx=None`, which is more efficient.
                 seq_idx = None
@@ -992,6 +990,15 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[Tuple, BaseModelOutputWithPast]:
+        if (
+            self.training
+            and (position_ids is not None or "cu_seq_lens_k" in flash_attn_kwargs)
+            and (self.config._attn_implementation != "flash_attention_2" or not is_fast_path_available)
+        ):
+            raise ValueError(
+                "Padding-free training using position_ids or FlashAttentionKwargs requires "
+                "the flash_attention_2 attention implementation and the mamba CUDA and Triton kernels."
+            )
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
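The same shape changes and training-time guard land in the modular source above (`modeling_bamba.py` is generated from it). For context, a hedged sketch of the call pattern the guard protects: padding-free training feeds packed `input_ids` plus `position_ids` instead of an attention mask. The checkpoint name is a placeholder, and this assumes a CUDA device with flash-attn and the mamba kernels installed; otherwise the new `ValueError` fires in training mode.

```python
import torch

from transformers import AutoModelForCausalLM

# Placeholder checkpoint; any Bamba checkpoint should behave the same way.
model = AutoModelForCausalLM.from_pretrained(
    "ibm-ai-platform/Bamba-9B",
    attn_implementation="flash_attention_2",  # required for padding-free training
    torch_dtype=torch.bfloat16,
).to("cuda")
model.train()

# Two sequences of lengths 3 and 5, concatenated with no padding in between.
input_ids = torch.tensor([[101, 102, 103, 201, 202, 203, 204, 205]], device="cuda")
position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3, 4]], device="cuda")

# No attention_mask: the position_ids carry the sequence boundaries.
logits = model(input_ids=input_ids, position_ids=position_ids).logits
```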
diff --git a/tests/models/bamba/test_modeling_bamba.py b/tests/models/bamba/test_modeling_bamba.py
index 5e879ea46a01..683f7ce142ed 100644
--- a/tests/models/bamba/test_modeling_bamba.py
+++ b/tests/models/bamba/test_modeling_bamba.py
@@ -18,11 +18,15 @@
 import unittest
 
 import pytest
+from pytest import mark
 
 from transformers import AutoTokenizer, BambaConfig, is_torch_available
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
 from transformers.models.bamba.modular_bamba import get_cu_seq_lens_from_position_ids, get_seq_idx_from_cu_seq_lens
 from transformers.testing_utils import (
+    require_flash_attn,
     require_torch,
+    require_torch_gpu,
     slow,
     torch_device,
 )
@@ -482,6 +486,89 @@ def _prepare_model_kwargs(input_ids, attention_mask, signature):
 
         # They should result in very similar logits
         torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5)
 
+    @require_flash_attn
+    @require_torch_gpu
+    @mark.flash_attn_test
+    def test_attn_mask_position_ids_flash_attn_equality(self):
+        r"""
+        Verify that the logits agree when using an attention mask, position_ids, or
+        FlashAttentionKwargs.
+        """
+        torch.manual_seed(42)
+        decoder_only_classes = []
+        for model_class in self.all_generative_model_classes:
+            config, _, _, _ = self.model_tester.prepare_config_and_inputs()
+            if config.is_encoder_decoder:
+                continue
+            else:
+                decoder_only_classes.append(model_class)
+        if len(decoder_only_classes) == 0:
+            self.skipTest(reason="No decoder-only architecture available for this model.")
+
+        # - Decoder-only architectures derived from encoder-decoder models could support it in theory, but we haven't
+        #   added support for it yet. We skip these models for now.
+        has_encoder_attributes = any(
+            attr_name
+            for attr_name in config.to_dict().keys()
+            if attr_name.startswith("encoder") and attr_name != "encoder_no_repeat_ngram_size"
+        )
+        if has_encoder_attributes:
+            self.skipTest(
+                reason="The decoder-only derived from encoder-decoder models are not expected to support padding-free training."
+            )
+
+        for model_class in decoder_only_classes:
+            config, input_ids, input_mask, _ = self.model_tester.prepare_config_and_inputs()
+            # Padding-free requires training = True and attn_implementation="flash_attention_2"
+            model = (
+                model_class._from_config(config, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16)
+                .to(torch_device)
+                .train()
+            )
+
+            non_padding_free_inputs = {"input_ids": input_ids, "attention_mask": input_mask}
+            attn_mask_logits = model(**non_padding_free_inputs).logits
+
+            # Build up padding-free tensors
+            padding_free_input_ids = torch.cat(
+                [batch[mask.bool()] for batch, mask in zip(input_ids, input_mask)], dim=-1
+            )[None]
+            position_ids_list = [
+                torch.arange(mask.sum(), device=mask.device, dtype=torch.int32) for mask in input_mask
+            ]
+            position_ids = torch.cat(position_ids_list, dim=-1)[None]
+            seq_lens = torch.cat(
+                [torch.tensor([t.numel()], device=input_mask.device, dtype=torch.int32) for t in position_ids_list],
+                dim=-1,
+            )
+            cu_seq_lens = torch.cat(
+                [
+                    torch.tensor([0], device=input_mask.device, dtype=torch.int32),
+                    seq_lens.cumsum(dim=-1, dtype=torch.int32),
+                ],
+                dim=-1,
+            )
+
+            position_ids_inputs = {"input_ids": padding_free_input_ids, "position_ids": position_ids}
+            position_ids_logits = model(**position_ids_inputs).logits
+
+            flash_attn_kwargs = FlashAttentionKwargs(
+                cu_seq_lens_q=cu_seq_lens,
+                cu_seq_lens_k=cu_seq_lens,
+                max_length_q=input_ids.shape[-1],
+                max_length_k=input_ids.shape[-1],
+            )
+            flash_attn_kwargs_logits = model(input_ids=padding_free_input_ids, **flash_attn_kwargs).logits
+
+            attn_mask_logits_reshaped = torch.cat(
+                [batch[mask.bool()] for batch, mask in zip(attn_mask_logits, input_mask)], dim=0
+            )[None]
+
+            torch.testing.assert_close(position_ids_logits, attn_mask_logits_reshaped)
+            # A higher tolerance is needed for the position_ids and FlashAttentionKwargs logits to
+            # match, for unknown reasons.
+            torch.testing.assert_close(position_ids_logits, flash_attn_kwargs_logits, atol=1e-3, rtol=1e-1)
+
     @slow
     @require_torch
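The tensor bookkeeping in the new test is easier to follow outside diff form; here is the same mask-to-`cu_seq_lens` computation on a tiny made-up batch (values are illustrative only):

```python
import torch

# A right-padded batch of two sequences (1 = real token, 0 = padding).
input_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])

# Per-sequence lengths, then cumulative boundaries with a leading zero.
seq_lens = input_mask.sum(dim=-1, dtype=torch.int32)  # tensor([3, 4], dtype=torch.int32)
cu_seq_lens = torch.cat(
    [torch.zeros(1, dtype=torch.int32), seq_lens.cumsum(dim=-1, dtype=torch.int32)]
)
print(cu_seq_lens)  # tensor([0, 3, 7], dtype=torch.int32)
```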
@@ -598,60 +686,41 @@ def test_simple_batched_generate_with_padding(self):
 def test_cu_seq_lens_from_position_ids() -> None:
     seq_length = 256
     chunks_per_batch = 4
-    batch_size = 1
 
     # Split each batch into `chunks_per_batch` sequences.
-    eos_idxs = (
-        torch.stack([torch.randperm(seq_length) for _ in range(batch_size)], dim=0)[:, : chunks_per_batch - 1]
-        .sort(dim=-1)
-        .values
-    )
-    seq_lens = torch.cat(
-        (torch.full((batch_size, 1), -1), eos_idxs, torch.full((batch_size, 1), seq_length - 1)), dim=-1
-    ).diff(dim=-1)
+    eos_idxs = torch.randperm(seq_length)[: chunks_per_batch - 1].sort(dim=-1).values
+    seq_lens = torch.cat((torch.full((1,), -1), eos_idxs, torch.full((1,), seq_length - 1)), dim=-1).diff(dim=-1)
 
     # Create the corresponding position_ids and seq_idx
-    position_ids = torch.stack(
-        [
-            torch.cat(
-                [torch.arange(s, dtype=torch.int32) for s in sl],
-                dim=0,
-            )
-            for sl in seq_lens
-        ],
+    position_ids = torch.cat(
+        [torch.arange(s, dtype=torch.int32) for s in seq_lens],
         dim=0,
-    )
+    )[None]
 
     cu_seq_lens_pred = get_cu_seq_lens_from_position_ids(position_ids)
     assert torch.allclose(
         cu_seq_lens_pred,
-        torch.cat(
-            [torch.tensor([[0]], dtype=seq_lens.dtype, device=seq_lens.device), seq_lens.cumsum(dim=-1)], dim=-1
-        ),
+        torch.cat([torch.tensor([0], dtype=seq_lens.dtype, device=seq_lens.device), seq_lens.cumsum(dim=-1)], dim=-1),
     )
 
 
 def test_seq_idx_from_cu_seq_lens() -> None:
     n_chunks = 5
     max_chunk_len = 64
-    batch_size = 1
-    seq_lens = torch.randint(1, max_chunk_len, size=(batch_size, n_chunks))
-    cu_seq_lens = torch.cat([torch.tensor([[0]]), seq_lens.cumsum(dim=-1)], dim=-1)
+    seq_lens = torch.randint(1, max_chunk_len, size=(n_chunks,))
+    cu_seq_lens = torch.cat([torch.tensor([0]), seq_lens.cumsum(dim=-1)], dim=-1)
     seq_idx = torch.cat(
         [
             torch.full(
-                (
-                    batch_size,
-                    n,
-                ),
+                (n,),
                 idx,
                 dtype=torch.int32,
                 device=cu_seq_lens.device,
             )
-            for idx, n in enumerate(seq_lens[0])
+            for idx, n in enumerate(seq_lens)
         ],
         dim=-1,
-    )
+    )[None]
 
     seq_idx_pred = get_seq_idx_from_cu_seq_lens(cu_seq_lens)
     assert torch.allclose(seq_idx_pred, seq_idx)
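Finally, a quick round-trip check of the two helpers under the new shape contract, runnable against this branch. The concrete values are illustrative, and assume position_ids that reset to zero at each packed sequence:

```python
import torch

from transformers.models.bamba.modular_bamba import (
    get_cu_seq_lens_from_position_ids,
    get_seq_idx_from_cu_seq_lens,
)

# Two packed sequences of lengths 3 and 5 in a single batch row.
position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3, 4]])

cu_seq_lens = get_cu_seq_lens_from_position_ids(position_ids)
assert cu_seq_lens.tolist() == [0, 3, 8]  # 1D, per the new contract

seq_idx = get_seq_idx_from_cu_seq_lens(cu_seq_lens)
assert seq_idx.tolist() == [[0, 0, 0, 1, 1, 1, 1, 1]]  # still (1, total_len)
```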