Fix how we compute the final non-padding token for ForSequenceClassification models #35911

Open · wants to merge 12 commits into main from fix_sequence_classification_padding_side
Conversation

@Rocketknight1 (Member) commented on Jan 27, 2025:

We have a lot of CLM models with ForSequenceClassification heads. These models are supposed to use the hidden state at the final non-padding token as the input to the classification head. However, the way they compute it is a bit weird - they get the index of the leftmost token that is equal to pad_token_id and subtract 1 from it. This has a few issues:

  • It breaks on left-padding
  • It creates index-arithmetic issues when pad_token_id is absent, which then require workarounds
  • It depends on an implementation detail of argmax(): when the maximum value occurs at multiple indices, it always returns the smallest one

This PR replaces that logic with a simpler approach that searches directly for what we actually want: the rightmost non-padding token, rather than the token next to the leftmost padding token. As a result, the same code works with left-padding, right-padding, no padding, or even padding on both sides (I don't think any models do that, but we're ready if they do!)
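
A minimal sketch of the new selection logic on a toy batch (the values and `pad_token_id = 0` are hypothetical, just for illustration):

```python
import torch

# Toy batch covering left-padding, right-padding, and no padding
input_ids = torch.tensor([
    [0, 0, 5, 6, 7],  # left-padded
    [3, 4, 9, 0, 0],  # right-padded
    [1, 2, 3, 4, 5],  # no padding
])
pad_token_id = 0  # hypothetical pad id for this sketch

# Rightmost non-padding token per row, following the logic in this PR
non_pad_mask = torch.ne(input_ids, pad_token_id).int()
token_indices = torch.arange(input_ids.shape[-1])
last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
print(last_non_pad_token)  # tensor([4, 2, 4])
```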

Fixes #30004
Fixes #35352
Fixes #35909

@Rocketknight1 force-pushed the fix_sequence_classification_padding_side branch from 4620592 to 8ccde63 on January 27, 2025, 19:14
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Rocketknight1 (Member Author) commented:

cc @ArthurZucker for core maintainer review - but if you have too much to do, let me know and I'll find another reviewer!

@Rocketknight1 (Member Author) commented:

cc @Cyrilvallez actually, as core-maintainer-in-training

@Cyrilvallez (Member) left a comment:

Hey! Nicely done, indeed much simpler than before! Just added some small comments! Let me know what you think!

src/transformers/models/bloom/modeling_bloom.py (outdated)
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = torch.ne(input_ids, self.config.pad_token_id).int().to(logits.device)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
Member:

Small nit as well, but maybe argmax() would be simpler than max().values

Member Author:

This isn't an argmax! It's actually a fake argmax where I make a masked index array and then compute the max. It just looks a lot like an argmax, lol

Member:

Unless I'm mistaken, each value of 1 in the mask takes its own index as its value when multiplied by torch.arange, so max and argmax are fully equivalent here. But it's a detail anyway!

Member Author:

True, yes! I could just use argmax instead
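
For reference, a quick toy check of the equivalence discussed above (hypothetical values, with `pad_token_id = 0`):

```python
import torch

input_ids = torch.tensor([[0, 0, 5, 6, 7],   # left-padded
                          [3, 4, 0, 0, 0]])  # right-padded
non_pad_mask = torch.ne(input_ids, 0).int()
token_indices = torch.arange(input_ids.shape[-1])
# Non-pad positions hold their own index; pad positions hold 0
masked_indices = token_indices * non_pad_mask

print(masked_indices.max(-1).values)  # tensor([4, 1])
print(masked_indices.argmax(-1))      # tensor([4, 1]) -- identical, as noted above
```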

Comment on lines -904 to +897
loss = self.hf_compute_loss(tf.reshape(labels, [-1, 1]), tf.reshape(in_logits, [-1, self.num_labels]))

pooled_logits = in_logits if in_logits is not None else logits
loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(pooled_logits, [-1, self.num_labels]))
Member:

I see the reshape dim has changed, I suppose it's the exact same without the dim 1, but just checking?

Member Author:

Yes, this should be the same or more correct! The loss is tested by the test_pt_tf_equivalence tests, so we should see if they go out of alignment
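
A small sketch of the shape change being discussed (hypothetical values; this assumes the underlying loss behaves like Keras's `SparseCategoricalCrossentropy(from_logits=True)`, which accepts labels shaped either `(N,)` or `(N, 1)`):

```python
import tensorflow as tf

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
labels = tf.constant([2, 0])
logits = tf.constant([[0.1, 0.2, 3.0],
                      [2.5, 0.3, 0.1]])

old_style = loss_fn(tf.reshape(labels, [-1, 1]), logits)  # labels shaped (N, 1)
new_style = loss_fn(tf.reshape(labels, [-1]), logits)     # labels shaped (N,)
print(old_style.numpy(), new_style.numpy())  # same loss value either way
```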

Comment on lines 959 to 963
last_non_pad_token = -1
logger.warning_once(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
Member:

Instead of adding the warning everywhere, maybe we should do it the other way around and actually remove it from the few classes where it is present? It is a fairly niche case, and people using it should be aware of the caveat IMO. It's nice to have fewer warnings in general, WDYT? (It would also let us simplify the branching.)

Member Author:

Hmm - a lot of classes actually support inputs_embeds as well as input_ids, so I think we still need this on most classes! It only fires when users actually pass inputs_embeds and not input_ids.

Member:

Be aware that it breaks torch.compile as well, should anyone want to use it! But once again, it is quite a niche case, and it's already here for some models, so I'll let you judge -- we can keep it if you think that removing it would cause confusion/errors for users 🤗

Member Author:

I think it's better to leave the warning, because the results will generally be wrong in that case. In the future, we might consider raising an error and asking users to supply an attention_mask in that situation!
