Add padding-free to bamba #35861

Open · wants to merge 10 commits into main from bamba-hf-padding-free-pr

Conversation

@garrett361 garrett361 commented Jan 23, 2025

What does this PR do?

Adds padding-free training to the BambaModel, enabling more efficient training with causal masking between disjoint sequences.

Performance: approximately 2x throughput improvement over naive padding for supervised finetuning on the Tulu v3 dataset with open-instruct. Tokens/sec/GPU (plots in the attached screenshots) for batch_size_per_gpu = 4:

  • 8 A100s: 600 --> 1200 tok/s/GPU

  • 32 A100s: 450 --> 750 tok/s/GPU

CC @fabianlim

CC reviewers of #34982: @ArthurZucker @molbap

Notes on Code

  • BambaAttention layers are untouched; only the BambaMixer mamba layer code is altered.
  • The padding-free path is only supported on CUDA and requires the mamba kernels.
  • Supports both the position_ids and FlashAttentionKwargs padding-free code paths; a rough sketch of both is below.
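
For reference, here is what the two input styles look like at the call site. This is illustrative only, not code from the PR: model is assumed to be a BambaForCausalLM in training mode, and the tensors describe two packed sequences of lengths 3 and 5.

import torch

# Two sequences of lengths 3 and 5 packed into a single batch-size-1 row, no padding.
input_ids = torch.randint(0, 1000, (1, 8))

# Path 1: position_ids that restart at 0 at each sequence boundary.
position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3, 4]])
out = model(input_ids=input_ids, position_ids=position_ids)

# Path 2: FlashAttentionKwargs-style cumulative sequence lengths.
cu_seq_lens = torch.tensor([0, 3, 8], dtype=torch.int32)
out = model(
    input_ids=input_ids,
    cu_seq_lens_q=cu_seq_lens,
    cu_seq_lens_k=cu_seq_lens,
    max_length_q=5,
    max_length_k=5,
)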

Notes on Tests

On both latest main and this PR branch the following tests/models/bamba/test_modeling_bamba.py tests are failing (with RUN_SLOW=1):

BambaModelTest::test_eager_matches_fa2_generate
BambaModelTest::test_flash_attention_2_padding_matches_padding_free_with_position_ids
BambaModelTest::test_sdpa_can_compile_dynamic
BambaModelTest::test_torchscript_output_attentions
BambaModelTest::test_torchscript_output_hidden_state
BambaModelTest::test_torchscript_simple
BambaModelIntegrationTest::test_simple_generate
  • The test_eager_matches_fa2_generate test seems flaky: sometimes it passes, other times it fails.
  • For test_flash_attention_2_padding_matches_padding_free_with_position_ids:
    • On main, this test fails because padding-free is not implemented.
    • On this PR branch this test fails because this PR only uses position_ids when model.training = True and this test explicitly calls eval() on the model. I have checked that this test passes when model.training = True. Edit: see BambaModelTest::test_attn_mask_position_ids_flash_attn_equality, also.
  • test_simple_generate appears to just need a simple edit for its expected text. It consistently fails with:
AssertionError: '<|be[35 chars]on this lovely evening? I hope you are all doing well. I am' != '<|be[35 chars]on this lovely evening? I hope you are all having a good time.'

where the generated and expected text differs at the very end.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@garrett361 force-pushed the bamba-hf-padding-free-pr branch from cdaf1e6 to eab1ae1 on January 23, 2025 at 20:48
@garrett361 closed this on January 23, 2025
@garrett361 reopened this on January 24, 2025
@garrett361 force-pushed the bamba-hf-padding-free-pr branch from eab1ae1 to c4874af on January 24, 2025 at 14:53
@Rocketknight1 (Member)

cc @ArthurZucker for bamba, but let me know if you want me to take a look since it seems like quite an extensive PR!

@garrett361 (Author)

it seems like quite an extensive PR!

I don't think it's very many changes, ultimately! Basically it just adds two helper functions so that position_ids and FlashAttentionKwargs get properly converted to the seq_idx arg that mamba expects:

  • get_cu_seq_lens_from_position_ids
  • get_seq_idx_from_cu_seq_lens

So, basically the above, making sure **kwargs get passed everywhere they should, and a little code cleanup.

@garrett361 force-pushed the bamba-hf-padding-free-pr branch from 49d007c to d35bcc6 on January 24, 2025 at 20:04
@garrett361 (Author)

Added a commit with BambaModelTest::test_attn_mask_position_ids_flash_attn_equality which tests the various code paths against each other.

@garrett361 force-pushed the bamba-hf-padding-free-pr branch 4 times, most recently from dfaca13 to 5d39d5e on January 28, 2025 at 16:56
@garrett361 (Author)

Hi @ArthurZucker @molbap, please let me know if I can answer any questions about this PR, thank you!

@ArthurZucker (Collaborator) left a comment

Hey! Good work and good PR. One thing is that for training we do recommend using the padding-free data collator that takes care of the flattening and passing appropriate kwargs. This prevents us from having to make too many changes, apart from popping the cu_seq_lens if they are inputs.
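
For illustration, a rough sketch of that padding-free collator setup (the checkpoint name and dataset below are placeholders, not part of this PR):

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorWithFlattening,
    Trainer,
    TrainingArguments,
)

model_name = "ibm-ai-platform/Bamba-9B"  # placeholder checkpoint name
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="flash_attention_2")

# The padding-free collator concatenates the examples of each batch into a single
# batch-size-1 sequence and emits position_ids that restart at 0 at every boundary,
# so no padding tokens (and no attention mask) are needed.
collator = DataCollatorWithFlattening()

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="bamba-padding-free", per_device_train_batch_size=4),
    train_dataset=tokenized_train_dataset,  # assumed: a pre-tokenized dataset
    data_collator=collator,
)
trainer.train()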

@@ -940,10 +931,39 @@ def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


def get_cu_seq_lens_from_position_ids(position_ids: torch.LongTensor) -> torch.LongTensor:
Collaborator

not super accurate to use position_ids because they can include padding

@garrett361 (Author) commented Feb 5, 2025

Yes, I thought about this for a while and it's why I am checking the non_increasing_pos_id condition.

The (just-updated; fixed an error) helper now reads:

def get_cu_seq_lens_from_position_ids(position_ids: torch.LongTensor) -> torch.LongTensor:
    batch_size = position_ids.shape[0]
    if batch_size != 1:
        raise ValueError("Only batch size 1 is supported.")
    device = position_ids.device
    idxs = torch.arange(1, position_ids.shape[1], device=device)
    non_increasing_pos_id = position_ids[0, 1:] <= position_ids[0, :-1]
    next_pos_is_zero = position_ids[0, 1:] == 0
    new_seq_idxs = non_increasing_pos_id | next_pos_is_zero
    cu_seq_lens = torch.empty(new_seq_idxs.sum() + 2, device=device, dtype=torch.int64)
    cu_seq_lens[0], cu_seq_lens[1:-1], cu_seq_lens[-1] = 0, idxs[new_seq_idxs], position_ids.shape[-1]
    return cu_seq_lens

My goal was to treat every padding token (assumed to be encoded as a negative number, like -100) as an individual sequence, while treating the non-padding sequences correctly.

So, an extreme case like position_ids = [-100, 0, 1, -100, -100, 0, 1, 2] would turn into:

# get_cu_seq_lens_from_position_ids
cu_seq_lens = [0, 1, 3, 4, 5, 8]

thanks to the non-increasing pos id check. Seems reasonable to me, since:

  1. The non-trivial sequences are assigned the correct lengths
  2. Every padding token is its own length-1 segment, so no unnecessary compute is spent attending across a span of padding tokens.
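
As a quick sanity check of that example (a sketch, assuming the helper above is importable; it only supports batch size 1):

import torch

position_ids = torch.tensor([[-100, 0, 1, -100, -100, 0, 1, 2]])
cu_seq_lens = get_cu_seq_lens_from_position_ids(position_ids)
print(cu_seq_lens.tolist())  # [0, 1, 3, 4, 5, 8]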

Compare to what is done in prepare_fa2_from_position_ids, which I was trying to improve upon:

cu_seq_lens = torch.cat(
    (
        indices_q[position_ids == 0],
        torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
    )
)

Passing the same position_ids into the above gives:

# prepare_fa2_from_position_ids
cu_seq_lens = [1, 5, 8]

which incorrectly implies there's a subseq of len 4, since it only checks for pos_id = 0.

Thoughts? Very open to improvements!

Comment on lines 951 to 964
def get_seq_idx_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
    seq_idx = torch.cat(
        [
            torch.full((n,), idx, dtype=torch.int32, device=cu_seq_lens.device)
            for idx, n in enumerate(torch.diff(cu_seq_lens, dim=-1))
        ]
    )
    return seq_idx[None]
Collaborator

we should first create a tensor and fill it, so there's a single tensor allocation

Author

Good call, I minimized the allocations in all three helper functions.
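
For example, one allocation-light way to write get_seq_idx_from_cu_seq_lens (a sketch, not necessarily the exact code now in the PR) is to repeat each sequence index by its length:

import torch

def get_seq_idx_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
    # Sequence lengths from the cumulative boundaries.
    seq_lens = torch.diff(cu_seq_lens)
    # Repeat each sequence's index by its length; repeat_interleave builds the output directly.
    seq_idx = torch.repeat_interleave(
        torch.arange(seq_lens.numel(), dtype=torch.int32, device=cu_seq_lens.device),
        seq_lens,
    )
    return seq_idx[None]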

Comment on lines 1023 to 1042
if not self.training:
    seq_idx = None
elif "cu_seq_lens_k" in kwargs:
    seq_idx = get_seq_idx_from_cu_seq_lens(kwargs["cu_seq_lens_k"])
elif position_ids is not None:
    cu_seq_lens = get_cu_seq_lens_from_position_ids(position_ids)
    if len(cu_seq_lens) == 2:
        # If cu_seq_lens only has two elements, then it is semantically equivalent to
        # `seq_idx=None`, which is more efficient.
        seq_idx = None
    else:
        seq_idx = get_seq_idx_from_cu_seq_lens(cu_seq_lens)
else:
    seq_idx = None
Collaborator

Suggested change (replace the branching above with):

seq_idx = None
if "cu_seq_lens_k" in kwargs:
    seq_idx = get_seq_idx_from_cu_seq_lens(kwargs["cu_seq_lens_k"])
elif position_ids is not None:
    cu_seq_lens = get_cu_seq_lens_from_position_ids(position_ids)
    if len(cu_seq_lens) != 2:
        seq_idx = get_seq_idx_from_cu_seq_lens(cu_seq_lens)

this is not exact but you can get rid of a lot of branches!

Collaborator

(to refactor a bit)

Author

Done, much better now.

@@ -1079,6 +1117,13 @@ def _init_weights(self, module):
module.weight.data[module.padding_idx].zero_()


def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
Collaborator

same comment about cat and for loops

Author

Yep, also changed now, thanks.
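
For illustration, a cat- and loop-free version of get_position_ids_from_cu_seq_lens could look like this (a sketch, not necessarily the PR's final code):

import torch

def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
    # Positions restart at 0 at every sequence boundary.
    total_len = int(cu_seq_lens[-1])
    seq_lens = torch.diff(cu_seq_lens)
    # Start offset of each token's sequence, broadcast to every token in that sequence.
    starts = torch.repeat_interleave(cu_seq_lens[:-1], seq_lens)
    position_ids = torch.arange(total_len, device=cu_seq_lens.device) - starts
    return position_ids[None]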

attention_mask: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.Tensor] = None,
Collaborator

Suggested change
seq_idx: Optional[torch.Tensor] = None,
sequence_index: Optional[torch.Tensor] = None,

not sure what this is used for (missing bit of doc)

Author

The mamba kernels use seq_idx to encode the packing of multiple sequences together, rather than cu_seq_lens or position_ids. The three are related like:

position_ids = [0, 1, 2, 0, 1, 0, 1, 2, 3, 4, 0, ...]  # position of each token within its seq
cu_seq_lens = [0, 3, 5, 10, ...]  # cumulative seq lens across seqs
seq_idx = [0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 3, ...]  # index of the seq each token belongs to

I'll add this info to the BambaMixer doc string.
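
For reference, the helpers in this PR tie the three together; a small sketch, truncating the example above to its first ten tokens (expected values shown in comments):

import torch

position_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3, 4]])
cu_seq_lens = get_cu_seq_lens_from_position_ids(position_ids)  # tensor([0, 3, 5, 10])
seq_idx = get_seq_idx_from_cu_seq_lens(cu_seq_lens)  # tensor([[0, 0, 0, 1, 1, 2, 2, 2, 2, 2]], dtype=torch.int32)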

@garrett361 force-pushed the bamba-hf-padding-free-pr branch from 5d39d5e to 6fdd9a0 on February 5, 2025 at 14:08
@garrett361 (Author)

Good work and good PR

Thank you!

one thing is that for training we do recommend using the padding-free data collator that takes care of the flattening and passing appropriate kwargs.

Yep, just trying to support all code paths. Hope I didn't misunderstand you here.

I think I have addressed all comments @ArthurZucker , let me know!

I also had to run modular_model_converter.py on all models to get past CI; some other PR seemed to cause minor CI breakage.
