Add padding-free to bamba #35861
base: main
Conversation
Force-pushed from cdaf1e6 to eab1ae1.
Force-pushed from eab1ae1 to c4874af.
cc @ArthurZucker for bamba, but let me know if you want me to take a look since it seems like quite an extensive PR!
I don't think it's very many changes, ultimately! Basically it just adds two helper functions so that
So, basically the above, making sure
Force-pushed from 49d007c to d35bcc6.
Added a commit with
Force-pushed from dfaca13 to 5d39d5e.
Hi @ArthurZucker @molbap, please let me know if I can answer any questions about this PR, thank you!
Hey! Good work and good PR. One thing: for training we recommend using the padding-free data collator, which takes care of the flattening and passes the appropriate kwargs. This prevents us from having to make too many changes, apart from popping the cu seqlens if they are inputs.
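For reference, a minimal sketch of that recommendation, using transformers' DataCollatorWithFlattening (exact constructor options and returned keys can vary across versions; values below are illustrative):

import torch
from transformers import DataCollatorWithFlattening

# The collator flattens a batch into one packed row and emits position_ids that
# restart at each example boundary, so no padding tokens are needed.
collator = DataCollatorWithFlattening()
features = [
    {"input_ids": [1, 2, 3]},
    {"input_ids": [4, 5]},
]
batch = collator(features)
# Roughly:
# batch["input_ids"]    -> [[1, 2, 3, 4, 5]]
# batch["position_ids"] -> [[0, 1, 2, 0, 1]]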
@@ -940,10 +931,39 @@ def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


def get_cu_seq_lens_from_position_ids(position_ids: torch.LongTensor) -> torch.LongTensor:
not super accurate to use position ids because they can include padding
Yes, I thought about this for a while and it's why I am checking the non_increasing_pos_id condition.
The (just-updated; fixed an error) helper now reads:
def get_cu_seq_lens_from_position_ids(position_ids: torch.LongTensor) -> torch.LongTensor:
    batch_size = position_ids.shape[0]
    if batch_size != 1:
        raise ValueError("Only batch size 1 is supported.")
    device = position_ids.device
    idxs = torch.arange(1, position_ids.shape[1], device=device)
    non_increasing_pos_id = position_ids[0, 1:] <= position_ids[0, :-1]
    next_pos_is_zero = position_ids[0, 1:] == 0
    new_seq_idxs = non_increasing_pos_id | next_pos_is_zero
    cu_seq_lens = torch.empty(new_seq_idxs.sum() + 2, device=device, dtype=torch.int64)
    cu_seq_lens[0], cu_seq_lens[1:-1], cu_seq_lens[-1] = 0, idxs[new_seq_idxs], position_ids.shape[-1]
    return cu_seq_lens
My goal was to treat every padding token (assumed to be encoded as a negative number, like -100) as an individual sequence, while treating the non-padding sequences correctly.
So, an extreme case like position_ids = [-100, 0, 1, -100, -100, 0, 1, 2] would turn into:
# get_cu_seq_lens_from_position_ids
cu_seq_lens = [0, 1, 3, 4, 5, 8]
thanks to the non-increasing pos id check. Seems reasonable to me, since:
- The non-trivial sequences are assigned the correct lengths
- Every padding token is a new, length-1 segment, so no unnecessary compute is spent attending across a span of padding tokens.
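As a quick sanity check of the example above (assuming the helper just shown is in scope):

import torch

position_ids = torch.tensor([[-100, 0, 1, -100, -100, 0, 1, 2]])
print(get_cu_seq_lens_from_position_ids(position_ids))
# tensor([0, 1, 3, 4, 5, 8])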
Compare to what is done in prepare_fa2_from_position_ids, which I was trying to improve upon (transformers/src/transformers/modeling_flash_attention_utils.py, lines 174 to 179 at c772bff):
cu_seq_lens = torch.cat(
    (
        indices_q[position_ids == 0],
        torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
    )
)
Passing the same position_ids into the above gives:
# prepare_fa2_from_position_ids
cu_seq_lens = [1, 5, 8]
which incorrectly implies there's a subsequence of length 4, since it only checks for pos_id == 0.
Thoughts? Very open to improvements!
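For reference, the quoted logic can be reproduced on the same example to confirm the [1, 5, 8] result (indices_q here stands in for the arange over the flattened tokens that prepare_fa2_from_position_ids builds internally):

import torch

position_ids = torch.tensor([-100, 0, 1, -100, -100, 0, 1, 2])  # flattened to 1D
indices_q = torch.arange(position_ids.numel(), dtype=torch.int32)
cu_seq_lens = torch.cat(
    (indices_q[position_ids == 0], torch.tensor(position_ids.size(), dtype=torch.int32))
)
print(cu_seq_lens)  # tensor([1, 5, 8], dtype=torch.int32)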
def get_seq_idx_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
    seq_idx = torch.cat(
        [
            torch.full((n,), idx, dtype=torch.int32, device=cu_seq_lens.device)
            for idx, n in enumerate(torch.diff(cu_seq_lens, dim=-1))
        ]
    )
    return seq_idx[None]
we should first create a tensor and fill it, so there is a single tensor allocation
Good call, I minimized the allocations in all three helper functions.
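For illustration, one way such a reduced-allocation version could look (a sketch using torch.repeat_interleave, not necessarily the exact code in the updated commit):

import torch

def get_seq_idx_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
    # Repeat each sequence index by its length in one call instead of building
    # per-sequence tensors and concatenating them.
    seq_lens = torch.diff(cu_seq_lens)
    seq_idx = torch.repeat_interleave(
        torch.arange(seq_lens.numel(), dtype=torch.int32, device=cu_seq_lens.device),
        seq_lens,
    )
    return seq_idx[None]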
if not self.training:
    seq_idx = None
elif "cu_seq_lens_k" in kwargs:
    seq_idx = get_seq_idx_from_cu_seq_lens(kwargs["cu_seq_lens_k"])
elif position_ids is not None:
    cu_seq_lens = get_cu_seq_lens_from_position_ids(position_ids)
    if len(cu_seq_lens) == 2:
        # If cu_seq_lens only has two elements, then it is semantically equivalent to
        # `seq_idx=None`, which is more efficient.
        seq_idx = None
    else:
        seq_idx = get_seq_idx_from_cu_seq_lens(cu_seq_lens)
else:
    seq_idx = None
Suggested change (replacing the block quoted above, minus the training-mode branch):

seq_idx = None
if "cu_seq_lens_k" in kwargs:
    seq_idx = get_seq_idx_from_cu_seq_lens(kwargs["cu_seq_lens_k"])
elif position_ids is not None:
    cu_seq_lens = get_cu_seq_lens_from_position_ids(position_ids)
    if len(cu_seq_lens) != 2:
        seq_idx = get_seq_idx_from_cu_seq_lens(cu_seq_lens)
this is not exact but you can get rid of a lot of branches!
(to refactor a bit)
Done, much better now.
@@ -1079,6 +1117,13 @@ def _init_weights(self, module):
                module.weight.data[module.padding_idx].zero_()


def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
same comment about cat and for loops
Yep, also changed now, thanks.
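For illustration, a vectorized version of this helper could look like the following (a sketch using torch.searchsorted, not necessarily the exact code in the updated commit):

import torch

def get_position_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
    # Positions restart at zero at every sequence boundary: take a global arange
    # and subtract the start offset of the sequence each token belongs to.
    token_idx = torch.arange(int(cu_seq_lens[-1]), device=cu_seq_lens.device)
    seq_of_token = torch.searchsorted(cu_seq_lens, token_idx, right=True) - 1
    return (token_idx - cu_seq_lens[seq_of_token])[None]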
attention_mask: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.Tensor] = None,

Suggested change:
- seq_idx: Optional[torch.Tensor] = None,
+ sequence_index: Optional[torch.Tensor] = None,
not sure what this is used for (missing bit of doc)
The mamba kernels use seq_idx to encode the packing of multiple sequences together, rather than cu_seq_lens or position_ids. The three are related like:

position_ids = [0, 1, 2, 0, 1, 0, 1, 2, 3, 0, ...]  # position of each token within its own sequence
cu_seq_lens = [0, 3, 5, 9, 10, ...]  # cumulative sequence lengths across sequences
seq_idx = [0, 0, 0, 1, 1, 2, 2, 2, 2, 3, ...]  # indexes the sequences directly

I'll add this info to the BambaMixer docstring.
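A quick numerical check of this relationship, using the helpers proposed in this PR (assuming they are in scope):

import torch

position_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3, 0]])
cu_seq_lens = get_cu_seq_lens_from_position_ids(position_ids)
print(cu_seq_lens)  # tensor([ 0,  3,  5,  9, 10])
print(get_seq_idx_from_cu_seq_lens(cu_seq_lens))
# tensor([[0, 0, 0, 1, 1, 2, 2, 2, 2, 3]], dtype=torch.int32)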
Force-pushed from 5d39d5e to 6fdd9a0.
Thank you!
Yep, just trying to support all code paths. Hope I didn't misunderstand you here. I think I have addressed all comments @ArthurZucker, let me know! I also had to run
What does this PR do?

Adds padding-free training to the BambaModel, enabling more efficient training with causal masking between disjoint sequences.

Performance: approximately 2x throughput improvement over naive padding for supervised finetuning on the Tulu v3 dataset with open-instruct. Tokens/sec/gpu numbers for batch_size_per_gpu = 4:
- 8 A100s: 600 --> 1200 Tok/s/gpu
- 32 A100s: 450 --> 750 Tok/s/gpu

CC @fabianlim
CC reviewers of #34982: @ArthurZucker @molbap
Notes on Code
- BambaAttention layers are untouched; only the BambaMixer mamba layer code is altered.
- The padding-free path only runs on cuda and requires the mamba kernels.
- Covers both the position_ids and FlashAttentionKwargs padding-free code paths (a brief usage sketch of the position_ids path follows below).
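As an illustration of the position_ids code path (a sketch only: the checkpoint id below is a placeholder, and the exact call is not code from this PR):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("<bamba-checkpoint>")  # placeholder checkpoint id
model.train()  # the packed path described here is only taken in training mode

# Two sequences of lengths 3 and 5 packed into a single row, with per-sequence positions.
input_ids = torch.tensor([[11, 12, 13, 21, 22, 23, 24, 25]])
position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3, 4]])
out = model(input_ids=input_ids, position_ids=position_ids, use_cache=False)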
Notes on Tests
On both latest main and this PR branch, the following tests/models/bamba/test_modeling_bamba.py tests are failing (with RUN_SLOW=1):
- test_eager_matches_fa2_generate seems flaky: sometimes it passes, other times it fails.
- test_flash_attention_2_padding_matches_padding_free_with_position_ids: on main, this test fails because padding-free is not implemented. On this branch, the padding-free path only uses position_ids when model.training = True, and this test explicitly calls eval() on the model. I have checked that this test passes when model.training = True. Edit: see BambaModelTest::test_attn_mask_position_ids_flash_attn_equality, also.
- test_simple_generate appears to just need a simple edit for its expected text. It consistently fails with a mismatch where the generated and expected text differ at the very end.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.