Fix how we compute the final non-padding token for ForSequenceClassification models #35911
base: main
Conversation
Force-pushed from 4620592 to 8ccde63
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
cc @ArthurZucker for core maintainer review - but if you have too much to do, let me know and I'll find another reviewer!
cc @Cyrilvallez actually, as core-maintainer-in-training
Hey! Nicely done, indeed much simpler than before! Just added some small comments! Let me know what you think!
# To handle both left- and right-padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
Small nit as well, but maybe `argmax()` would be simpler than `max().values`
This isn't an `argmax`! It's actually a fake argmax where I make a masked index array and then compute the `max`. It just looks a lot like an argmax, lol
Unless I'm mistaken, each value of `1` in the mask will take its index as value when multiplying with `torch.arange`, so `max` and `argmax` are fully equivalent here. But it's a detail anyway!
True, yes! I could just use argmax instead
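For illustration, here's a quick standalone check of the equivalence being discussed (the token ids are made up for the example):

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([
    [0, 0, 11, 12, 13],   # left-padded
    [21, 22, 23, 0, 0],   # right-padded
])

# The masked-index trick from the diff: non-pad positions keep their own index, pad positions become 0
non_pad_mask = torch.ne(input_ids, pad_token_id).int()
token_indices = torch.arange(input_ids.shape[-1])
masked_indices = token_indices * non_pad_mask

# The max *value* is the rightmost non-pad index, and argmax returns the position where that
# maximum occurs - which is the same number, since every non-zero entry equals its own index.
print(masked_indices.max(-1).values)  # tensor([4, 2])
print(masked_indices.argmax(-1))      # tensor([4, 2])
```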
loss = self.hf_compute_loss(tf.reshape(labels, [-1, 1]), tf.reshape(in_logits, [-1, self.num_labels]))

pooled_logits = in_logits if in_logits is not None else logits
loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(pooled_logits, [-1, self.num_labels]))
I see the reshape dim has changed, I suppose it's the exact same without the dim `1`, but just checking?
Yes, this should be the same or more correct! The loss is tested by the `test_pt_tf_equivalence` tests, so we should see if they go out of alignment.
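For anyone reading along, a minimal sketch of why the label reshape shouldn't matter, assuming the loss bottoms out in Keras's `SparseCategoricalCrossentropy` (used here as a stand-in for `hf_compute_loss`; treat that mapping as an assumption):

```python
import tensorflow as tf

num_labels = 3
logits = tf.constant([[2.0, 0.5, -1.0], [0.1, 1.5, 0.3]])  # made-up pooled logits, shape (batch, num_labels)
labels = tf.constant([0, 1])

# Stand-in for hf_compute_loss, for illustration only
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

loss_with_extra_dim = loss_fn(tf.reshape(labels, [-1, 1]), tf.reshape(logits, [-1, num_labels]))
loss_flat = loss_fn(tf.reshape(labels, [-1]), tf.reshape(logits, [-1, num_labels]))
print(float(loss_with_extra_dim), float(loss_flat))  # same value either way
```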
last_non_pad_token = -1
logger.warning_once(
    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
Instead of adding the warning everywhere, maybe we should do it the other way around and actually remove it from the few classes where it is present? It is a fairly niche case, and people using it should be aware of the caveat IMO. It's nice to have fewer warnings in general, WDYT? (It would also allow us to simplify the branching.)
Hmm - I think a lot of classes actually support `inputs_embeds` as well as `input_ids`, so we still need this on most classes, I think! It just only fires when users actually pass `inputs_embeds` and not `input_ids`.
Be aware that it breaks `torch.compile` as well, should anyone want to use it! But once again, it is quite a niche case, and it's already here for some models, so I'll let you judge -- we can keep it if you think that removing it would bring confusion/errors for users 🤗
I think it's better to leave the warning, because the results will generally be wrong in that case. In the future, we might consider raising an error and asking users to supply an `attention_mask` instead!
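For reference, here's the branching being discussed, paraphrased into a standalone helper (the function name and signature are hypothetical; the real logic lives inline in each model's forward):

```python
from typing import Optional

import torch


def last_non_pad_index(input_ids: Optional[torch.Tensor], pad_token_id: int):
    # Hypothetical helper, for illustration only.
    if input_ids is not None:
        # Rightmost token that is not the pad token: works for left- and right-padding.
        non_pad_mask = torch.ne(input_ids, pad_token_id).int()
        token_indices = torch.arange(input_ids.shape[-1])
        return (token_indices * non_pad_mask).max(-1).values
    # With inputs_embeds there are no token ids to compare against pad_token_id, so the
    # model falls back to the last position (and warns). That is only correct when the
    # inputs are unpadded or left-padded, which is why the warning is kept.
    return -1
```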
We have a lot of CLM models with `ForSequenceClassification` heads. These models are supposed to use the hidden state at the final non-padding token as the input to the classification head. However, the way they compute it is a bit weird - they get the index of the leftmost token that is equal to `pad_token_id` and subtract 1 from it. This has a few issues:

- Cases where `pad_token_id` is absent need workarounds
- It relies on the tie-breaking behaviour of `argmax()`, specifically that when multiple indices have the same maximum value, it always returns the smallest one

This PR replaces that logic with simpler logic that actually searches for what we want: the rightmost non-padding token, not the token next to the leftmost padding token. This means the same logic works with left-padding, right-padding, no-padding, or even padding on both sides (I don't think any models do that, but we're ready if they do!)
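As a concrete illustration of the above (token ids made up, pad token id 0 assumed, and the "old heuristic" below is a paraphrase of the description rather than a verbatim copy of the removed code):

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([
    [0, 0, 11, 12, 13],    # left-padded: last real token at index 4
    [21, 22, 23, 0, 0],    # right-padded: last real token at index 2
    [0, 31, 32, 33, 0],    # padded on both sides: last real token at index 3
])

# New logic from this PR: index of the rightmost token that is not pad_token_id
non_pad_mask = torch.ne(input_ids, pad_token_id).int()
token_indices = torch.arange(input_ids.shape[-1])
print((token_indices * non_pad_mask).max(-1).values)  # tensor([4, 2, 3])

# Old heuristic: index of the leftmost pad token, minus one (relying on argmax returning
# the smallest index among ties). -1 means "last position", which happens to be right for
# the left-padded row but points at a trailing *pad* token for the both-sides row.
print(torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1)  # tensor([-1,  2, -1])
```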
Fixes #30004
Fixes #35352
Fixes #35909