diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py
index 128571d568f5..8fd389ebe16d 100644
--- a/src/transformers/models/bloom/modeling_bloom.py
+++ b/src/transformers/models/bloom/modeling_bloom.py
@@ -1123,18 +1123,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py
index 44ca1e12643f..955b809c3810 100644
--- a/src/transformers/models/ctrl/modeling_ctrl.py
+++ b/src/transformers/models/ctrl/modeling_ctrl.py
@@ -791,18 +791,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py
index af5301efa78c..fc74195d84c2 100644
--- a/src/transformers/models/diffllama/modeling_diffllama.py
+++ b/src/transformers/models/diffllama/modeling_diffllama.py
@@ -1217,18 +1217,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py
index 6ea2b5b68f1f..7d63e56e0f7f 100644
--- a/src/transformers/models/falcon/modeling_falcon.py
+++ b/src/transformers/models/falcon/modeling_falcon.py
@@ -1360,18 +1360,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index 6afb21a6b9c4..0d7c3288571b 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -949,18 +949,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index 09f916cdf378..db3ed1a5d558 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -1039,18 +1039,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py
index 1b656c70739a..14c071abefb6 100644
--- a/src/transformers/models/glm/modeling_glm.py
+++ b/src/transformers/models/glm/modeling_glm.py
@@ -959,18 +959,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index b9fb391bb0c1..3c7b13c89991 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -1397,18 +1397,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
index f647e8dd517b..ee5ce5ed9fe1 100644
--- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -1284,18 +1284,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
index 70c3c8a9638d..981bddd12c3b 100755
--- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py
+++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -1102,18 +1102,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
index b71e7d1bd700..f5346ca6f5f8 100755
--- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -1207,18 +1207,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py
index 3498104b1180..3733c14631d0 100644
--- a/src/transformers/models/gptj/modeling_gptj.py
+++ b/src/transformers/models/gptj/modeling_gptj.py
@@ -1244,18 +1244,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py
index 5730964b23a3..e5c89a2d06ca 100644
--- a/src/transformers/models/helium/modeling_helium.py
+++ b/src/transformers/models/helium/modeling_helium.py
@@ -946,18 +946,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py
index 7538abad96c4..948203bbf218 100755
--- a/src/transformers/models/jamba/modeling_jamba.py
+++ b/src/transformers/models/jamba/modeling_jamba.py
@@ -1680,18 +1680,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py
index e1ed4a1213e3..884c06132d01 100644
--- a/src/transformers/models/jetmoe/modeling_jetmoe.py
+++ b/src/transformers/models/jetmoe/modeling_jetmoe.py
@@ -1456,18 +1456,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 1097f43e14a7..1fec59389fea 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -948,18 +948,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index f225bbb1f66c..a8acd9a1f4c8 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -1037,18 +1037,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index f9f8b312320e..8318bc9cb28d 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -1190,18 +1190,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py
index 7922882dbcd5..499a8db9cf23 100644
--- a/src/transformers/models/mpt/modeling_mpt.py
+++ b/src/transformers/models/mpt/modeling_mpt.py
@@ -682,18 +682,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py
index cf610e9c80ea..be6593ac829d 100644
--- a/src/transformers/models/nemotron/modeling_nemotron.py
+++ b/src/transformers/models/nemotron/modeling_nemotron.py
@@ -1195,18 +1195,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py
index d6836268c5ad..734e3ef27c01 100644
--- a/src/transformers/models/openai/modeling_openai.py
+++ b/src/transformers/models/openai/modeling_openai.py
@@ -807,18 +807,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py
index 6b395cca3d71..6c6d245c6256 100644
--- a/src/transformers/models/opt/modeling_opt.py
+++ b/src/transformers/models/opt/modeling_opt.py
@@ -1299,18 +1299,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py
index cd4da514c8f2..d0080fdb9617 100644
--- a/src/transformers/models/persimmon/modeling_persimmon.py
+++ b/src/transformers/models/persimmon/modeling_persimmon.py
@@ -1010,18 +1010,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py
index 5755dcba07c9..ad8c532fb8ff 100644
--- a/src/transformers/models/phi/modeling_phi.py
+++ b/src/transformers/models/phi/modeling_phi.py
@@ -922,18 +922,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py
index b4e7920f3b4b..842221e7162e 100644
--- a/src/transformers/models/phi3/modeling_phi3.py
+++ b/src/transformers/models/phi3/modeling_phi3.py
@@ -1058,18 +1058,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py
index 088a13e3570c..d5d0c4a94731 100644
--- a/src/transformers/models/phimoe/modeling_phimoe.py
+++ b/src/transformers/models/phimoe/modeling_phimoe.py
@@ -1598,18 +1598,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py
index 0796bbf061f6..58486f23c2c7 100644
--- a/src/transformers/models/qwen2/modeling_qwen2.py
+++ b/src/transformers/models/qwen2/modeling_qwen2.py
@@ -933,18 +933,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
index 885bb09b1101..c30ff64573f1 100644
--- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
+++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -1439,18 +1439,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py
index 011b1b5a66e9..602ee81330d4 100755
--- a/src/transformers/models/stablelm/modeling_stablelm.py
+++ b/src/transformers/models/stablelm/modeling_stablelm.py
@@ -1266,18 +1266,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py
index 8df08b02cedc..f5beb88a024e 100644
--- a/src/transformers/models/starcoder2/modeling_starcoder2.py
+++ b/src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -945,18 +945,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py
index 999ba029dc7e..1b2e129489dd 100644
--- a/src/transformers/models/zamba/modeling_zamba.py
+++ b/src/transformers/models/zamba/modeling_zamba.py
@@ -1433,18 +1433,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py
index 069fdcb3b37e..c53885ad15b5 100644
--- a/src/transformers/models/zamba2/modeling_zamba2.py
+++ b/src/transformers/models/zamba2/modeling_zamba2.py
@@ -1859,18 +1859,17 @@ def forward(
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
         else:
-            if input_ids is not None:
-                # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
-                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
-                token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
-                last_non_pad_token = (token_indices * non_pad_mask).max(-1).values
-            else:
-                last_non_pad_token = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
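
For reference, the pooling logic that every hunk above switches to can be sanity-checked in isolation. The following is a minimal standalone sketch (not part of the patch; the pad_token_id value and tensors are made up for illustration) showing that multiplying the position indices by the non-pad mask and taking the row-wise max picks the rightmost non-pad token, which handles both left- and right-padded batches:

import torch

# Illustrative inputs (hypothetical values, not taken from the patch).
pad_token_id = 0
input_ids = torch.tensor([
    [0, 0, 5, 6, 7],  # left-padded row
    [8, 9, 4, 0, 0],  # right-padded row
])
logits = torch.randn(2, 5, 3)  # (batch_size, seq_len, num_labels)

# Same computation as the added lines in each hunk above.
non_pad_mask = (input_ids != pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).max(-1).values

pooled_logits = logits[torch.arange(input_ids.shape[0], device=logits.device), last_non_pad_token]
print(last_non_pad_token)  # tensor([4, 2]): index of the rightmost non-pad token per row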