From b480fff12781f7b8c1a38bcc04e5801337a8d8a8 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Sun, 15 Dec 2024 13:54:00 +0100
Subject: [PATCH 01/14] add native liger-kernel orpo loss

---
 tests/test_orpo_trainer.py  |  46 +++++++++-
 trl/trainer/orpo_config.py  |   3 +
 trl/trainer/orpo_trainer.py | 163 ++++++++++++++++++++++++------------
 3 files changed, 157 insertions(+), 55 deletions(-)

diff --git a/tests/test_orpo_trainer.py b/tests/test_orpo_trainer.py
index d2eaee3947..a0d1a23d4d 100644
--- a/tests/test_orpo_trainer.py
+++ b/tests/test_orpo_trainer.py
@@ -19,7 +19,7 @@
 from datasets import load_dataset
 from parameterized import parameterized
 from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
-from transformers.testing_utils import require_peft
+from transformers.testing_utils import require_liger_kernel, require_peft
 
 from trl import ORPOConfig, ORPOTrainer
 from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
@@ -148,3 +148,47 @@ def test_orpo_trainer_with_lora(self, config_name):
                 # check the params have changed - ignore 0 biases
                 if param.sum() != 0:
                     self.assertFalse(torch.equal(param, new_param))
+
+    @require_liger_kernel
+    def test_orpo_trainer_with_liger(self):
+        """Test ORPO trainer with Liger loss enabled."""
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = ORPOConfig(
+                output_dir=tmp_dir,
+                per_device_train_batch_size=2,
+                max_steps=3,
+                remove_unused_columns=False,
+                gradient_accumulation_steps=1,
+                learning_rate=9e-1,
+                eval_strategy="steps",
+                logging_steps=1,
+                beta=0.1,
+                report_to="none",
+                use_liger_loss=True,  # Enable Liger loss
+            )
+
+            dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
+
+            trainer = ORPOTrainer(
+                model=self.model,
+                args=training_args,
+                processing_class=self.tokenizer,
+                train_dataset=dummy_dataset["train"],
+                eval_dataset=dummy_dataset["test"],
+            )
+
+            # Verify Liger is being used
+            self.assertTrue(trainer._using_liger)
+
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # check the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                # check the params have changed - ignore 0 biases
+                if param.sum() != 0:
+                    self.assertFalse(torch.equal(param, new_param))
diff --git a/trl/trainer/orpo_config.py b/trl/trainer/orpo_config.py
index b7e2ef7ad0..3e8585dfad 100644
--- a/trl/trainer/orpo_config.py
+++ b/trl/trainer/orpo_config.py
@@ -61,6 +61,8 @@ class ORPOConfig(TrainingArguments):
             string.
         dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
            Number of processes to use for processing the dataset.
+        use_liger_loss (`bool`, *optional*, defaults to `False`):
+            Whether to use Liger loss.
""" learning_rate: float = 1e-6 @@ -76,3 +78,4 @@ class ORPOConfig(TrainingArguments): is_encoder_decoder: Optional[bool] = None model_init_kwargs: Optional[dict[str, Any]] = None dataset_num_proc: Optional[int] = None + use_liger_loss: bool = False diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index f94522923b..7326ce38d9 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -47,7 +47,7 @@ ) from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import EvalLoopOutput -from transformers.utils import is_peft_available, is_torch_fx_proxy +from transformers.utils import is_liger_kernel_available, is_peft_available, is_torch_fx_proxy from transformers.utils.deprecation import deprecate_kwarg from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt @@ -68,7 +68,6 @@ if is_peft_available(): from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training - if is_wandb_available(): import wandb @@ -357,6 +356,22 @@ def make_inputs_require_grad(module, input, output): "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." ) + # Import Liger loss if enabled + if self.args.use_liger_loss and is_liger_kernel_available(): + try: + from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss + + self.orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta) + self._using_liger = True + except ImportError: + warnings.warn( + "Liger package not found. Falling back to default ORPO loss implementation. " + "Install liger-kernel for optimized performance." + ) + self._using_liger = False + else: + self._using_liger = False + def _prepare_deepspeed(self, model: PreTrainedModelWrapper): # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 deepspeed_plugin = self.accelerator.state.deepspeed_plugin @@ -756,51 +771,74 @@ def concatenated_forward( concatenated_batch["concatenated_input_ids"], attention_mask=concatenated_batch["concatenated_attention_mask"], use_cache=False, + output_hidden_states=True if self._using_liger else False, **model_kwargs, ) - all_logits = outputs.logits - - def cross_entropy_loss(logits, labels): - if not self.is_encoder_decoder: - # Shift so that tokens < n predict n - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - logits = logits.view(-1, logits.shape[-1]) - labels = labels.view(-1) - # Enable model parallelism - labels = labels.to(logits.device) - loss = loss_fct(logits, labels) - return loss - if self.is_encoder_decoder: - labels = concatenated_batch["concatenated_labels"].clone() + if self._using_liger: + lm_head = model.get_output_embeddings() + + # Get the last hidden state from hidden_states tuple + last_hidden_state = outputs.hidden_states[-1] + + # return the final loss and aux_outputs tuple + return self.orpo_loss_fn( + lm_head.weight, + last_hidden_state, + concatenated_batch["concatenated_labels"], + lm_head.bias if hasattr(lm_head, "bias") else None, + ) else: - labels = concatenated_batch["concatenated_input_ids"].clone() - attention_mask = concatenated_batch["concatenated_attention_mask"] - labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id) + all_logits = outputs.logits + + def cross_entropy_loss(logits, labels): + if not self.is_encoder_decoder: + # Shift so that tokens < n 
predict n + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + logits = logits.view(-1, logits.shape[-1]) + labels = labels.view(-1) + # Enable model parallelism + labels = labels.to(logits.device) + loss = loss_fct(logits, labels) + return loss + + if self.is_encoder_decoder: + labels = concatenated_batch["concatenated_labels"].clone() + else: + labels = concatenated_batch["concatenated_input_ids"].clone() + attention_mask = concatenated_batch["concatenated_attention_mask"] + labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id) - chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen]) + chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen]) - all_logps = self.get_batch_logps( - all_logits, - concatenated_batch["concatenated_labels"], - average_log_prob=True, - is_encoder_decoder=self.is_encoder_decoder, - label_pad_token_id=self.label_pad_token_id, - ) + all_logps = self.get_batch_logps( + all_logits, + concatenated_batch["concatenated_labels"], + average_log_prob=True, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + ) - chosen_logps = all_logps[:len_chosen] - rejected_logps = all_logps[len_chosen:] + chosen_logps = all_logps[:len_chosen] + rejected_logps = all_logps[len_chosen:] - chosen_logits = all_logits[:len_chosen] - rejected_logits = all_logits[len_chosen:] + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] - if self.aux_loss_enabled: - return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss) + if self.aux_loss_enabled: + return ( + chosen_logps, + rejected_logps, + chosen_logits, + rejected_logits, + chosen_nll_loss, + outputs.aux_loss, + ) - return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss) + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss) def get_batch_loss_metrics( self, @@ -812,21 +850,41 @@ def get_batch_loss_metrics( metrics = {} forward_output = self.concatenated_forward(model, batch) - ( - policy_chosen_logps, - policy_rejected_logps, - policy_chosen_logits, - policy_rejected_logits, - policy_nll_loss, - ) = forward_output[:5] - if self.aux_loss_enabled: - aux_loss = forward_output[5] + if self._using_liger: + # full ORPO loss and aux outputs + ( + loss, + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + policy_nll_loss, + chosen_rewards, + rejected_rewards, + log_odds_ratio, + log_odds_chosen, + ), + ) = forward_output + else: + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + policy_nll_loss, + ) = forward_output[:5] + if self.aux_loss_enabled: + aux_loss = forward_output[5] + + losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss( + policy_chosen_logps, policy_rejected_logps + ) + # full ORPO loss + loss = policy_nll_loss - losses.mean() - losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss( - policy_chosen_logps, policy_rejected_logps - ) - # full ORPO loss - loss = policy_nll_loss - losses.mean() + if self.aux_loss_enabled: + loss += self.aux_loss_coef * aux_loss reward_accuracies = (chosen_rewards > rejected_rewards).float() @@ -846,8 +904,6 @@ def get_batch_loss_metrics( xm.mark_step() # needed because 
.item() calls
         for k, v in metrics.items():
             metrics[k] = v.item()
-        if self.aux_loss_enabled:
-            loss += self.aux_loss_coef * aux_loss
 
         return loss, metrics
 
@@ -859,7 +915,6 @@ def compute_loss(
         num_items_in_batch=None,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
         compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
-
         with compute_loss_context_manager:
             loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")

From 44aa20c56ea2b50630a4a27be8d9c122b30f9477 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Sun, 15 Dec 2024 16:52:53 +0100
Subject: [PATCH 02/14] Update tests/test_orpo_trainer.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
---
 tests/test_orpo_trainer.py | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/tests/test_orpo_trainer.py b/tests/test_orpo_trainer.py
index a0d1a23d4d..88f94f7b6c 100644
--- a/tests/test_orpo_trainer.py
+++ b/tests/test_orpo_trainer.py
@@ -155,14 +155,6 @@ def test_orpo_trainer_with_liger(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             training_args = ORPOConfig(
                 output_dir=tmp_dir,
-                per_device_train_batch_size=2,
-                max_steps=3,
-                remove_unused_columns=False,
-                gradient_accumulation_steps=1,
-                learning_rate=9e-1,
-                eval_strategy="steps",
-                logging_steps=1,
-                beta=0.1,
                 report_to="none",
                 use_liger_loss=True,  # Enable Liger loss
             )

From 7682e3192602d56645612dcdd8ba6f9ec5b40be7 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Sun, 15 Dec 2024 16:57:36 +0100
Subject: [PATCH 03/14] passing self.args.use_liger_loss without liger installed should raise an error

---
 trl/trainer/orpo_trainer.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py
index 7326ce38d9..90f7b6e260 100644
--- a/trl/trainer/orpo_trainer.py
+++ b/trl/trainer/orpo_trainer.py
@@ -357,18 +357,22 @@ def make_inputs_require_grad(module, input, output):
             )
 
         # Import Liger loss if enabled
-        if self.args.use_liger_loss and is_liger_kernel_available():
+        if self.args.use_liger_loss:
+            if not is_liger_kernel_available():
+                raise ValueError(
+                    "You set `use_liger_loss=True` but the liger kernel is not available. "
+                    "Please install liger-kernel first: `pip install liger-kernel`"
+                )
             try:
                 from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss
 
                 self.orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta)
                 self._using_liger = True
             except ImportError:
-                warnings.warn(
-                    "Liger package not found. Falling back to default ORPO loss implementation. "
-                    "Install liger-kernel for optimized performance."
+                raise ImportError(
+                    "Failed to import LigerFusedLinearORPOLoss from liger-kernel. "
+                    "Please ensure you have the correct liger-kernel version installed."
) - self._using_liger = False else: self._using_liger = False From c383bf6116a5ef384c811f2709b967a0e24080c1 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 15 Dec 2024 17:46:45 +0100 Subject: [PATCH 04/14] update liger version --- setup.py | 2 +- tests/test_orpo_trainer.py | 3 --- trl/trainer/orpo_trainer.py | 28 ++++++++++++---------------- 3 files changed, 13 insertions(+), 20 deletions(-) diff --git a/setup.py b/setup.py index 28d483e0c8..7b355eac6b 100644 --- a/setup.py +++ b/setup.py @@ -85,7 +85,7 @@ "diffusers": ["diffusers>=0.18.0"], "judges": ["openai>=1.23.2", "llm-blender>=0.0.2"], # liger-kernel depends on triton, which is only available on Linux https://github.com/triton-lang/triton#compatibility - "liger": ["liger-kernel>=0.4.0; sys_platform != 'win32'"], + "liger": ["liger-kernel>=0.5.1; sys_platform != 'win32'"], "mergekit": ["mergekit>=0.0.5.1"], "peft": ["peft>=0.8.0"], "quantization": ["bitsandbytes"], diff --git a/tests/test_orpo_trainer.py b/tests/test_orpo_trainer.py index 88f94f7b6c..a4f5ad6973 100644 --- a/tests/test_orpo_trainer.py +++ b/tests/test_orpo_trainer.py @@ -169,9 +169,6 @@ def test_orpo_trainer_with_liger(self): eval_dataset=dummy_dataset["test"], ) - # Verify Liger is being used - self.assertTrue(trainer._using_liger) - previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} trainer.train() diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index 90f7b6e260..204d7b078e 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -363,18 +363,9 @@ def make_inputs_require_grad(module, input, output): "You set `use_liger_loss=True` but the liger kernel is not available. " "Please install liger-kernel first: `pip install liger-kernel`" ) - try: - from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss - - self.orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta) - self._using_liger = True - except ImportError: - raise ImportError( - "Failed to import LigerFusedLinearORPOLoss from liger-kernel. " - "Please ensure you have the correct liger-kernel version installed." 
- ) - else: - self._using_liger = False + from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss + + self.orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta) def _prepare_deepspeed(self, model: PreTrainedModelWrapper): # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 @@ -775,23 +766,28 @@ def concatenated_forward( concatenated_batch["concatenated_input_ids"], attention_mask=concatenated_batch["concatenated_attention_mask"], use_cache=False, - output_hidden_states=True if self._using_liger else False, + output_hidden_states=True if self.args.use_liger_loss else False, **model_kwargs, ) - if self._using_liger: + if self.args.use_liger_loss: lm_head = model.get_output_embeddings() # Get the last hidden state from hidden_states tuple last_hidden_state = outputs.hidden_states[-1] # return the final loss and aux_outputs tuple - return self.orpo_loss_fn( + loss, aux_outputs = self.orpo_loss_fn( lm_head.weight, last_hidden_state, concatenated_batch["concatenated_labels"], lm_head.bias if hasattr(lm_head, "bias") else None, ) + + if self.aux_loss_enabled: + loss += self.aux_loss_coef * outputs.aux_loss + + return loss, aux_outputs else: all_logits = outputs.logits @@ -854,7 +850,7 @@ def get_batch_loss_metrics( metrics = {} forward_output = self.concatenated_forward(model, batch) - if self._using_liger: + if self.args.use_liger_loss: # full ORPO loss and aux outputs ( loss, From 220f7541d69579dbf7696ca582ed8ba5c1381201 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 15 Dec 2024 17:49:49 +0100 Subject: [PATCH 05/14] make import more readable --- trl/trainer/orpo_trainer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index 204d7b078e..98b7bfafe3 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -77,6 +77,9 @@ if is_torch_xla_available(): import torch_xla.core.xla_model as xm +if is_liger_kernel_available(): + from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss + class ORPOTrainer(Trainer): r""" @@ -363,8 +366,6 @@ def make_inputs_require_grad(module, input, output): "You set `use_liger_loss=True` but the liger kernel is not available. " "Please install liger-kernel first: `pip install liger-kernel`" ) - from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss - self.orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta) def _prepare_deepspeed(self, model: PreTrainedModelWrapper): From b3f327037740cc41aad7aa96476b2a7e55c19aef Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 17 Dec 2024 11:09:38 +0100 Subject: [PATCH 06/14] skip the lm_head when use_liger_loss is true --- trl/trainer/orpo_config.py | 4 ++++ trl/trainer/orpo_trainer.py | 28 ++++++++++++++++------------ 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/trl/trainer/orpo_config.py b/trl/trainer/orpo_config.py index 3e8585dfad..93a743171c 100644 --- a/trl/trainer/orpo_config.py +++ b/trl/trainer/orpo_config.py @@ -63,6 +63,9 @@ class ORPOConfig(TrainingArguments): Number of processes to use for processing the dataset. use_liger_loss (`bool`, *optional*, defaults to `False`): Whether to use Liger loss. + base_model_class_name (`str`, *optional*, defaults to `"model"`): + The name of the base model class (e.g. 
`"model"` for `LlamaForCausalLM`) to use for skipping the + LM head when `use_liger_loss` is `True`. """ learning_rate: float = 1e-6 @@ -79,3 +82,4 @@ class ORPOConfig(TrainingArguments): model_init_kwargs: Optional[dict[str, Any]] = None dataset_num_proc: Optional[int] = None use_liger_loss: bool = False + base_model_class_name: str = "model" diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index 98b7bfafe3..11a1e728ab 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -763,24 +763,21 @@ def concatenated_forward( if self.aux_loss_enabled: model_kwargs["output_router_logits"] = True - outputs = model( - concatenated_batch["concatenated_input_ids"], - attention_mask=concatenated_batch["concatenated_attention_mask"], - use_cache=False, - output_hidden_states=True if self.args.use_liger_loss else False, - **model_kwargs, - ) - if self.args.use_liger_loss: + # skip the lm head and get the last hidden state + base_model = getattr(model, self.args.base_model_class_name) + outputs = base_model( + concatenated_batch["concatenated_input_ids"], + attention_mask=concatenated_batch["concatenated_attention_mask"], + use_cache=False, + **model_kwargs, + ) lm_head = model.get_output_embeddings() - # Get the last hidden state from hidden_states tuple - last_hidden_state = outputs.hidden_states[-1] - # return the final loss and aux_outputs tuple loss, aux_outputs = self.orpo_loss_fn( lm_head.weight, - last_hidden_state, + outputs[0], concatenated_batch["concatenated_labels"], lm_head.bias if hasattr(lm_head, "bias") else None, ) @@ -790,6 +787,13 @@ def concatenated_forward( return loss, aux_outputs else: + outputs = model( + concatenated_batch["concatenated_input_ids"], + attention_mask=concatenated_batch["concatenated_attention_mask"], + use_cache=False, + output_hidden_states=False, + **model_kwargs, + ) all_logits = outputs.logits def cross_entropy_loss(logits, labels): From afaf5a86d0649b9f4295ac041e66e70172709560 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 17 Dec 2024 11:56:37 +0100 Subject: [PATCH 07/14] use get_decoder() --- trl/trainer/orpo_config.py | 4 ---- trl/trainer/orpo_trainer.py | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/trl/trainer/orpo_config.py b/trl/trainer/orpo_config.py index 93a743171c..3e8585dfad 100644 --- a/trl/trainer/orpo_config.py +++ b/trl/trainer/orpo_config.py @@ -63,9 +63,6 @@ class ORPOConfig(TrainingArguments): Number of processes to use for processing the dataset. use_liger_loss (`bool`, *optional*, defaults to `False`): Whether to use Liger loss. - base_model_class_name (`str`, *optional*, defaults to `"model"`): - The name of the base model class (e.g. `"model"` for `LlamaForCausalLM`) to use for skipping the - LM head when `use_liger_loss` is `True`. 
""" learning_rate: float = 1e-6 @@ -82,4 +79,3 @@ class ORPOConfig(TrainingArguments): model_init_kwargs: Optional[dict[str, Any]] = None dataset_num_proc: Optional[int] = None use_liger_loss: bool = False - base_model_class_name: str = "model" diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index 11a1e728ab..ce914a61d5 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -765,7 +765,7 @@ def concatenated_forward( if self.args.use_liger_loss: # skip the lm head and get the last hidden state - base_model = getattr(model, self.args.base_model_class_name) + base_model = model.get_decoder() outputs = base_model( concatenated_batch["concatenated_input_ids"], attention_mask=concatenated_batch["concatenated_attention_mask"], From 5776a4e3541f25e260b71a3a2b2ba882760c7c36 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 17 Dec 2024 13:03:02 +0100 Subject: [PATCH 08/14] make it a bit more robust --- trl/trainer/orpo_config.py | 4 ++++ trl/trainer/orpo_trainer.py | 7 +++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/trl/trainer/orpo_config.py b/trl/trainer/orpo_config.py index 3e8585dfad..1cd5fa0a00 100644 --- a/trl/trainer/orpo_config.py +++ b/trl/trainer/orpo_config.py @@ -63,6 +63,9 @@ class ORPOConfig(TrainingArguments): Number of processes to use for processing the dataset. use_liger_loss (`bool`, *optional*, defaults to `False`): Whether to use Liger loss. + base_model_attribute_name (`str`, *optional*, defaults to `"model"`): + Name of the attribute in the model that contains the base model. This is used to get the base model from the + model when the model does not have a `get_decoder` method in the case when `use_liger_loss` is `True`. """ learning_rate: float = 1e-6 @@ -79,3 +82,4 @@ class ORPOConfig(TrainingArguments): model_init_kwargs: Optional[dict[str, Any]] = None dataset_num_proc: Optional[int] = None use_liger_loss: bool = False + base_model_attribute_name: str = "model" diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index ce914a61d5..ec4c0b3b98 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -765,7 +765,10 @@ def concatenated_forward( if self.args.use_liger_loss: # skip the lm head and get the last hidden state - base_model = model.get_decoder() + if hasattr(model, "get_decoder"): + base_model = model.get_decoder() + else: + base_model = getattr(model, self.args.base_model_attribute_name) outputs = base_model( concatenated_batch["concatenated_input_ids"], attention_mask=concatenated_batch["concatenated_attention_mask"], @@ -777,7 +780,7 @@ def concatenated_forward( # return the final loss and aux_outputs tuple loss, aux_outputs = self.orpo_loss_fn( lm_head.weight, - outputs[0], + outputs.last_hidden_state, concatenated_batch["concatenated_labels"], lm_head.bias if hasattr(lm_head, "bias") else None, ) From 568e21a27f84059d1b4e021ef8d83d7d39681fae Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 19 Dec 2024 11:38:36 +0100 Subject: [PATCH 09/14] add back missing line --- trl/trainer/orpo_trainer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index 924838afee..a55a226e84 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -824,6 +824,9 @@ def cross_entropy_loss(logits, labels): label_pad_token_id=self.label_pad_token_id, ) + chosen_logps = all_logps[:len_chosen] + rejected_logps = all_logps[len_chosen:] + if not self.is_encoder_decoder: chosen_logits = 
all_logits[:len_chosen, :-1, :] rejected_logits = all_logits[len_chosen:, :-1, :] From f4979b0f150c32cf54b57ce5e1bff5359742982c Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 19 Dec 2024 12:02:42 +0100 Subject: [PATCH 10/14] pass is_enc_dec --- trl/trainer/orpo_trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index a55a226e84..89382ab316 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -366,7 +366,9 @@ def make_inputs_require_grad(module, input, output): "You set `use_liger_loss=True` but the liger kernel is not available. " "Please install liger-kernel first: `pip install liger-kernel`" ) - self.orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta) + self.orpo_loss_fn = LigerFusedLinearORPOLoss( + ignore_index=self.label_pad_token_id, beta=self.beta, is_encoder_decoder=self.is_encoder_decoder + ) def _prepare_deepspeed(self, model: PreTrainedModelWrapper): # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 From 5c6744ff47333d7c4666d106661df91c9baea654 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 19 Dec 2024 20:42:01 +0100 Subject: [PATCH 11/14] call orpo_loss_fn with shifted inputs --- trl/trainer/orpo_trainer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index 89382ab316..ac21f40831 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -366,9 +366,7 @@ def make_inputs_require_grad(module, input, output): "You set `use_liger_loss=True` but the liger kernel is not available. 
" "Please install liger-kernel first: `pip install liger-kernel`" ) - self.orpo_loss_fn = LigerFusedLinearORPOLoss( - ignore_index=self.label_pad_token_id, beta=self.beta, is_encoder_decoder=self.is_encoder_decoder - ) + self.orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta) def _prepare_deepspeed(self, model: PreTrainedModelWrapper): # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 @@ -782,8 +780,10 @@ def concatenated_forward( # return the final loss and aux_outputs tuple loss, aux_outputs = self.orpo_loss_fn( lm_head.weight, - outputs.last_hidden_state, - concatenated_batch["concatenated_labels"], + outputs.last_hidden_state[:, :-1] if not self.is_encoder_decoder else outputs.last_hidden_state, + concatenated_batch["concatenated_labels"][:, 1:] + if not self.is_encoder_decoder + else concatenated_batch["concatenated_labels"], lm_head.bias if hasattr(lm_head, "bias") else None, ) From e1918b77b9496956efc3592d94341a5f2be2284a Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 28 Dec 2024 16:41:30 +0100 Subject: [PATCH 12/14] add back the orpo nll labels --- trl/trainer/orpo_trainer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index 0d170a38dd..a01502653e 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -816,12 +816,17 @@ def cross_entropy_loss(logits, labels): loss = loss_fct(logits, labels) return loss - labels = concatenated_batch["concatenated_labels"].clone() + if self.is_encoder_decoder: + labels = concatenated_batch["concatenated_labels"].clone() + else: + labels = concatenated_batch["concatenated_input_ids"].clone() + attention_mask = concatenated_batch["concatenated_attention_mask"] + labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id) chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen]) all_logps = self.get_batch_logps( all_logits, - labels, + concatenated_batch["concatenated_labels"], average_log_prob=True, is_encoder_decoder=self.is_encoder_decoder, label_pad_token_id=self.label_pad_token_id, From 5ee37a613467bfcd45f73748d032a6d4673737e2 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 29 Dec 2024 14:55:11 +0100 Subject: [PATCH 13/14] call with nll_target --- trl/trainer/orpo_trainer.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index a01502653e..fdf2699029 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -764,6 +764,14 @@ def concatenated_forward( if self.aux_loss_enabled: model_kwargs["output_router_logits"] = True + # orpo nll target is with respect to the concatenated prompt + completionlabels + if self.is_encoder_decoder: + labels = concatenated_batch["concatenated_labels"].clone() + else: + labels = concatenated_batch["concatenated_input_ids"].clone() + attention_mask = concatenated_batch["concatenated_attention_mask"] + labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id) + if self.args.use_liger_loss: # skip the lm head and get the last hidden state if hasattr(model, "get_decoder"): @@ -786,6 +794,7 @@ def concatenated_forward( if not self.is_encoder_decoder else concatenated_batch["concatenated_labels"], lm_head.bias if hasattr(lm_head, "bias") else None, + nll_target=labels[:, 1:] if not 
self.is_encoder_decoder else labels, ) if self.aux_loss_enabled: @@ -816,12 +825,6 @@ def cross_entropy_loss(logits, labels): loss = loss_fct(logits, labels) return loss - if self.is_encoder_decoder: - labels = concatenated_batch["concatenated_labels"].clone() - else: - labels = concatenated_batch["concatenated_input_ids"].clone() - attention_mask = concatenated_batch["concatenated_attention_mask"] - labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id) chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen]) all_logps = self.get_batch_logps( From f6ffbf6bb1169b771c1558fa2aade3a89556ddbe Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 3 Jan 2025 17:02:36 +0100 Subject: [PATCH 14/14] fix enc-dec --- trl/trainer/orpo_trainer.py | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index ea8a49c3e2..254795533d 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -776,17 +776,31 @@ def concatenated_forward( labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id) if self.args.use_liger_loss: - # skip the lm head and get the last hidden state - if hasattr(model, "get_decoder"): - base_model = model.get_decoder() + if self.is_encoder_decoder: + # 1. Get encoder outputs + encoder_outputs = model.get_encoder()( + concatenated_batch["concatenated_input_ids"], + attention_mask=concatenated_batch["concatenated_attention_mask"], + return_dict=True, + ) + # 2. Get decoder outputs + outputs = model.get_decoder()( + input_ids=model_kwargs["decoder_input_ids"], + encoder_hidden_states=encoder_outputs.last_hidden_state, + use_cache=False, + ) else: - base_model = getattr(model, self.args.base_model_attribute_name) - outputs = base_model( - concatenated_batch["concatenated_input_ids"], - attention_mask=concatenated_batch["concatenated_attention_mask"], - use_cache=False, - **model_kwargs, - ) + # skip the lm head and get the last hidden state + if hasattr(model, "get_decoder"): + base_model = model.get_decoder() + else: + base_model = getattr(model, self.args.base_model_attribute_name) + outputs = base_model( + concatenated_batch["concatenated_input_ids"], + attention_mask=concatenated_batch["concatenated_attention_mask"], + use_cache=False, + **model_kwargs, + ) lm_head = model.get_output_embeddings() # return the final loss and aux_outputs tuple