From 2355ca1ba66846446f391d66ec4125660deef447 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 19 Dec 2024 15:59:54 +0000 Subject: [PATCH 01/13] first step --- trl/trainer/kto_trainer.py | 747 ++++++++++--------------------------- 1 file changed, 190 insertions(+), 557 deletions(-) diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index d054d97e7d..0d07c81c77 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -30,7 +30,7 @@ import torch.nn.functional as F import transformers from accelerate import PartialState -from accelerate.utils import is_deepspeed_available, tqdm +from accelerate.utils import is_deepspeed_available from datasets import Dataset, concatenate_datasets from packaging import version from torch.utils.data import DataLoader, SequentialSampler @@ -168,105 +168,83 @@ def _process_tokens(example: dict[str, Any], model: "PreTrainedModel" = None, ** f"{kwargs['prefix']}label": example["label"], } - if not kwargs["is_encoder_decoder"]: - # Check issues below for more details - # 1. https://github.com/huggingface/trl/issues/907 - # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 - # 3. https://github.com/LianjiaTech/BELLE/issues/337 + # Check issues below for more details + # 1. https://github.com/huggingface/trl/issues/907 + # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 + # 3. https://github.com/LianjiaTech/BELLE/issues/337 - if not isinstance(prompt, str): - raise ValueError(f"prompt should be an str but got {type(prompt)}") + if not isinstance(prompt, str): + raise ValueError(f"prompt should be an str but got {type(prompt)}") - if not isinstance(completion, str): - raise ValueError(f"completion should be an str but got {type(completion)}") + if not isinstance(completion, str): + raise ValueError(f"completion should be an str but got {type(completion)}") - # keys of format prompt_* refers to just the prompt and answer_* refers to just the answer - all_tokens = { - "prompt_input_ids": example["prompt_input_ids"], - "prompt_attention_mask": example["prompt_attention_mask"], - "answer_input_ids": example["answer_input_ids"], - "answer_attention_mask": example["answer_attention_mask"], - } + # keys of format prompt_* refers to just the prompt and answer_* refers to just the answer + all_tokens = { + "prompt_input_ids": example["prompt_input_ids"], + "prompt_attention_mask": example["prompt_attention_mask"], + "answer_input_ids": example["answer_input_ids"], + "answer_attention_mask": example["answer_attention_mask"], + } - # calculate max length by checking if BOS/EOS is already there - max_length = kwargs["max_length"] - bos_token_id = kwargs["tokenizer"].bos_token_id - eos_token_id = kwargs["tokenizer"].eos_token_id - if len(all_tokens["prompt_input_ids"]) > 0 and bos_token_id != all_tokens["prompt_input_ids"][0]: - max_length -= 1 - if len(all_tokens["answer_input_ids"]) > 0 and eos_token_id != all_tokens["answer_input_ids"][-1]: - max_length -= 1 - - # if combined sequence is too long (> max_length - 1 for BOS token - 1 for EOS), truncate the prompt - if len(all_tokens["prompt_input_ids"]) + len(all_tokens["answer_input_ids"]) > max_length: - for k in ["prompt_input_ids", "prompt_attention_mask"]: - if kwargs["truncation_mode"] == "keep_start": - all_tokens[k] = all_tokens[k][: kwargs["max_prompt_length"]] - elif kwargs["truncation_mode"] == "keep_end": - all_tokens[k] = all_tokens[k][-kwargs["max_prompt_length"] :] - else: 
- raise ValueError(f"Unknown truncation mode: {kwargs['truncation_mode']}") - - # if that's still too long, truncate the response - if len(all_tokens["prompt_input_ids"]) + len(all_tokens["answer_input_ids"]) > max_length: - for k in ["answer_input_ids", "answer_attention_mask"]: - all_tokens[k] = all_tokens[k][: max_length - kwargs["max_prompt_length"]] - - # all input_ids and attention mask as is. We then check if we need to add BOS/EOS tokens - batch[f"{kwargs['prefix']}prompt_input_ids"] = all_tokens["prompt_input_ids"] - batch[f"{kwargs['prefix']}prompt_attention_mask"] = all_tokens["prompt_attention_mask"] - batch[f"{kwargs['prefix']}completion_input_ids"] = ( - all_tokens["prompt_input_ids"] + all_tokens["answer_input_ids"] - ) - batch[f"{kwargs['prefix']}completion_attention_mask"] = ( - all_tokens["prompt_attention_mask"] + all_tokens["answer_attention_mask"] - ) + # calculate max length by checking if BOS/EOS is already there + max_length = kwargs["max_length"] + bos_token_id = kwargs["tokenizer"].bos_token_id + eos_token_id = kwargs["tokenizer"].eos_token_id + if len(all_tokens["prompt_input_ids"]) > 0 and bos_token_id != all_tokens["prompt_input_ids"][0]: + max_length -= 1 + if len(all_tokens["answer_input_ids"]) > 0 and eos_token_id != all_tokens["answer_input_ids"][-1]: + max_length -= 1 + + # if combined sequence is too long (> max_length - 1 for BOS token - 1 for EOS), truncate the prompt + if len(all_tokens["prompt_input_ids"]) + len(all_tokens["answer_input_ids"]) > max_length: + for k in ["prompt_input_ids", "prompt_attention_mask"]: + if kwargs["truncation_mode"] == "keep_start": + all_tokens[k] = all_tokens[k][: kwargs["max_prompt_length"]] + elif kwargs["truncation_mode"] == "keep_end": + all_tokens[k] = all_tokens[k][-kwargs["max_prompt_length"] :] + else: + raise ValueError(f"Unknown truncation mode: {kwargs['truncation_mode']}") + + # if that's still too long, truncate the response + if len(all_tokens["prompt_input_ids"]) + len(all_tokens["answer_input_ids"]) > max_length: + for k in ["answer_input_ids", "answer_attention_mask"]: + all_tokens[k] = all_tokens[k][: max_length - kwargs["max_prompt_length"]] + + # all input_ids and attention mask as is. 
We then check if we need to add BOS/EOS tokens + batch[f"{kwargs['prefix']}prompt_input_ids"] = all_tokens["prompt_input_ids"] + batch[f"{kwargs['prefix']}prompt_attention_mask"] = all_tokens["prompt_attention_mask"] + batch[f"{kwargs['prefix']}completion_input_ids"] = all_tokens["prompt_input_ids"] + all_tokens["answer_input_ids"] + batch[f"{kwargs['prefix']}completion_attention_mask"] = ( + all_tokens["prompt_attention_mask"] + all_tokens["answer_attention_mask"] + ) - # add BOS, which affects both prompt and the full completion - if bos_token_id is not None: - if len(all_tokens["prompt_input_ids"]) == 0 or bos_token_id != all_tokens["prompt_input_ids"][0]: - batch[f"{kwargs['prefix']}prompt_input_ids"] = [bos_token_id] + batch[ - f"{kwargs['prefix']}prompt_input_ids" - ] - batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + batch[ - f"{kwargs['prefix']}prompt_attention_mask" - ] - batch[f"{kwargs['prefix']}completion_input_ids"] = [bos_token_id] + batch[ - f"{kwargs['prefix']}completion_input_ids" - ] - batch[f"{kwargs['prefix']}completion_attention_mask"] = [1] + batch[ - f"{kwargs['prefix']}completion_attention_mask" - ] - # add EOS, which affects only the full completion - if len(all_tokens["answer_input_ids"]) == 0 or eos_token_id != all_tokens["answer_input_ids"][-1]: - batch[f"{kwargs['prefix']}completion_input_ids"] = batch[f"{kwargs['prefix']}completion_input_ids"] + [ - eos_token_id + # add BOS, which affects both prompt and the full completion + if bos_token_id is not None: + if len(all_tokens["prompt_input_ids"]) == 0 or bos_token_id != all_tokens["prompt_input_ids"][0]: + batch[f"{kwargs['prefix']}prompt_input_ids"] = [bos_token_id] + batch[ + f"{kwargs['prefix']}prompt_input_ids" ] - batch[f"{kwargs['prefix']}completion_attention_mask"] = batch[ + batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + batch[f"{kwargs['prefix']}prompt_attention_mask"] + batch[f"{kwargs['prefix']}completion_input_ids"] = [bos_token_id] + batch[ + f"{kwargs['prefix']}completion_input_ids" + ] + batch[f"{kwargs['prefix']}completion_attention_mask"] = [1] + batch[ f"{kwargs['prefix']}completion_attention_mask" - ] + [1] - - batch[f"{kwargs['prefix']}completion_labels"] = batch[f"{kwargs['prefix']}completion_input_ids"][:] - batch[f"{kwargs['prefix']}completion_labels"][: len(batch[f"{kwargs['prefix']}prompt_input_ids"])] = [ - kwargs["label_pad_token_id"] - ] * len(batch[f"{kwargs['prefix']}prompt_input_ids"]) - else: - completion_tokens = kwargs["tokenizer"]( - completion, truncation=True, max_length=kwargs["max_completion_length"], add_special_tokens=True - ) - prompt_tokens = kwargs["tokenizer"]( - prompt, truncation=True, max_length=kwargs["max_prompt_length"], add_special_tokens=True - ) - - batch[f"{kwargs['prefix']}prompt_input_ids"] = prompt_tokens["input_ids"] - batch[f"{kwargs['prefix']}prompt_attention_mask"] = prompt_tokens["attention_mask"] - - batch[f"{kwargs['prefix']}completion_labels"] = completion_tokens["input_ids"] - batch[f"{kwargs['prefix']}completion_attention_mask"] = completion_tokens["attention_mask"] - if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"): - batch[f"{kwargs['prefix']}completion_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( - labels=torch.tensor(batch["completion_labels"]) - ) + ] + # add EOS, which affects only the full completion + if len(all_tokens["answer_input_ids"]) == 0 or eos_token_id != all_tokens["answer_input_ids"][-1]: + batch[f"{kwargs['prefix']}completion_input_ids"] = 
batch[f"{kwargs['prefix']}completion_input_ids"] + [ + eos_token_id + ] + batch[f"{kwargs['prefix']}completion_attention_mask"] = batch[ + f"{kwargs['prefix']}completion_attention_mask" + ] + [1] + + batch[f"{kwargs['prefix']}completion_labels"] = batch[f"{kwargs['prefix']}completion_input_ids"][:] + batch[f"{kwargs['prefix']}completion_labels"][: len(batch[f"{kwargs['prefix']}prompt_input_ids"])] = [ + kwargs["label_pad_token_id"] + ] * len(batch[f"{kwargs['prefix']}prompt_input_ids"]) return batch @@ -452,20 +430,13 @@ def make_inputs_require_grad(module, input, output): " Please install with `pip install wandb` to resolve." ) - if model is not None: - self.is_encoder_decoder = model.config.is_encoder_decoder - elif args.is_encoder_decoder is None: - raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.") - else: - self.is_encoder_decoder = args.is_encoder_decoder - self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) self.model_adapter_name = model_adapter_name self.ref_adapter_name = ref_adapter_name if ref_model: self.ref_model = ref_model - elif self.is_peft_model or args.precompute_ref_log_probs: + elif self.is_peft_model or False: # The `model` with adapters turned off will be used as the reference model self.ref_model = None else: @@ -496,21 +467,12 @@ def make_inputs_require_grad(module, input, output): max_prompt_length = args.max_prompt_length max_completion_length = None - if args.max_completion_length is None and self.is_encoder_decoder: - warnings.warn( - "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the KTOTrainer's init" - " it will be set to `128` by default, but you should do it yourself in the future.", - UserWarning, - ) - max_completion_length = 128 - if args.max_completion_length is not None and self.is_encoder_decoder: - max_completion_length = args.max_completion_length if data_collator is None: data_collator = DPODataCollatorWithPadding( pad_token_id=processing_class.pad_token_id, label_pad_token_id=args.label_pad_token_id, - is_encoder_decoder=self.is_encoder_decoder, + is_encoder_decoder=False, ) if args.remove_unused_columns: @@ -531,7 +493,6 @@ def make_inputs_require_grad(module, input, output): if self.ref_model is not None: disable_dropout_in_model(self.ref_model) - self.loss_type = args.loss_type self.max_length = max_length self.generate_during_eval = args.generate_during_eval self.label_pad_token_id = args.label_pad_token_id @@ -540,17 +501,6 @@ def make_inputs_require_grad(module, input, output): self.truncation_mode = args.truncation_mode self.max_completion_length = max_completion_length self.processing_class = processing_class - self.precompute_ref_log_probs = args.precompute_ref_log_probs - - # Not all losses require a KL calculation - self.calculate_KL = True - if self.loss_type in ["apo_zero_unpaired"]: - self.calculate_KL = False - - # Since ref_logs are precomputed on the first call to get_train/eval_dataloader - # keep track of first called to avoid computation of future calls - self._precomputed_train_ref_log_probs = False - self._precomputed_eval_ref_log_probs = False # metric self._stored_metrics = defaultdict(lambda: defaultdict(list)) @@ -559,16 +509,7 @@ def make_inputs_require_grad(module, input, output): self.beta = args.beta self.desirable_weight = args.desirable_weight self.undesirable_weight = args.undesirable_weight - self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) 
self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0) - if self.aux_loss_enabled and self.aux_loss_coef == 0.0: - warnings.warn( - "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to " - "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value " - "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary " - "loss.", - UserWarning, - ) # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the # input tensor associated with the key "input_ids". However, in KTO, the sampled data does not include the @@ -622,7 +563,6 @@ def make_inputs_require_grad(module, input, output): fn_kwargs = { "prefix": "", - "is_encoder_decoder": self.is_encoder_decoder, "tokenizer": self.processing_class, "max_length": self.max_length, "truncation_mode": self.truncation_mode, @@ -656,54 +596,53 @@ def make_inputs_require_grad(module, input, output): ) # Get KL datasets if needed - if self.calculate_KL: - if args.per_device_train_batch_size <= 1: - raise ValueError( - "Actual (not effective) batch size must be > 1. KTO will not work properly because the KL term will be equivalent to the implied reward." - ) + if args.per_device_train_batch_size <= 1: + raise ValueError( + "Actual (not effective) batch size must be > 1. KTO will not work properly because the KL term will be equivalent to the implied reward." + ) + + # create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size + # i.e., (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_n), ..., (x_n, y_1) = (x'_1, y'_1), ..., (x'_n, y'_n) + train_kl_dataset = train_dataset.map( + _get_kl_dataset, + batched=True, + batch_size=args.per_device_train_batch_size, + num_proc=args.dataset_num_proc, + desc="Extracting KL train dataset", + ) - # create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size - # i.e., (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_n), ..., (x_n, y_1) = (x'_1, y'_1), ..., (x'_n, y'_n) - train_kl_dataset = train_dataset.map( + fn_kwargs["prefix"] = "kl_" + train_kl_dataset = train_kl_dataset.map( + _process_tokens, + fn_kwargs=fn_kwargs, + num_proc=args.dataset_num_proc, + remove_columns=[c for c in train_kl_dataset.column_names if c in train_dataset.column_names], + desc="Processing tokenized train KL dataset", + ) + + # merge the datasets + train_dataset = concatenate_datasets([train_dataset, train_kl_dataset], axis=1) + + if eval_dataset is not None: + # Get KL dataset + eval_kl_dataset = eval_dataset.map( _get_kl_dataset, batched=True, batch_size=args.per_device_train_batch_size, num_proc=args.dataset_num_proc, - desc="Extracting KL train dataset", + desc="Extracting eval KL dataset", ) - fn_kwargs["prefix"] = "KL_" - train_kl_dataset = train_kl_dataset.map( + eval_kl_dataset = eval_kl_dataset.map( _process_tokens, fn_kwargs=fn_kwargs, num_proc=args.dataset_num_proc, - remove_columns=[c for c in train_kl_dataset.column_names if c in train_dataset.column_names], - desc="Processing tokenized train KL dataset", + remove_columns=[c for c in eval_kl_dataset.column_names if c in eval_dataset.column_names], + desc="Processing tokenized eval KL dataset", ) # merge the datasets - train_dataset = concatenate_datasets([train_dataset, train_kl_dataset], axis=1) - - if eval_dataset is not None: - # Get KL dataset - eval_kl_dataset = eval_dataset.map( - _get_kl_dataset, - 
batched=True, - batch_size=args.per_device_train_batch_size, - num_proc=args.dataset_num_proc, - desc="Extracting eval KL dataset", - ) - - eval_kl_dataset = eval_kl_dataset.map( - _process_tokens, - fn_kwargs=fn_kwargs, - num_proc=args.dataset_num_proc, - remove_columns=[c for c in eval_kl_dataset.column_names if c in eval_dataset.column_names], - desc="Processing tokenized eval KL dataset", - ) - - # merge the datasets - eval_dataset = concatenate_datasets([eval_dataset, eval_kl_dataset], axis=1) + eval_dataset = concatenate_datasets([eval_dataset, eval_kl_dataset], axis=1) # calculate dataset desirability balance num_desirable = max(sum(train_dataset["label"]), 1) @@ -753,15 +692,8 @@ def make_inputs_require_grad(module, input, output): "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." ) - # Deepspeed Zero-3 does not support precompute_ref_log_probs - if self.is_deepspeed_enabled: - if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs: - raise ValueError( - "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`." - ) - if self.ref_model is None: - if not (self.is_peft_model or self.precompute_ref_log_probs): + if not (self.is_peft_model or False): raise ValueError( "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`" ) @@ -814,346 +746,83 @@ def null_ref_context(self): if self.ref_adapter_name: self.model.set_adapter(self.model_adapter_name or "default") - def get_train_dataloader(self) -> DataLoader: - """ - Returns the training [`~torch.utils.data.DataLoader`]. - - Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`. - """ - - if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs: - dataloader_params = { - "batch_size": self.args.per_device_train_batch_size, - "collate_fn": self.data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - "shuffle": False, - } - - # prepare dataloader - data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params)) - reference_completion_logps = [] - reference_KL_logps = [] - - for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"): - reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch) - - reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp) - reference_completion_logps.append(reference_completion_logp.cpu()) - - if self.calculate_KL: - reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp) - reference_KL_logps.append(reference_KL_logp.cpu()) - - self.train_dataset = self.train_dataset.add_column( - name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy() - ) - - if self.calculate_KL: - self.train_dataset = self.train_dataset.add_column( - name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy() - ) - - self._precomputed_train_ref_log_probs = True - - return super().get_train_dataloader() - - def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: - """ - Returns the evaluation [`~torch.utils.data.DataLoader`]. - - Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`. 
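
# Note: a toy illustration of the KL-pair construction described earlier in this
# patch, where completions are rotated by one position within each batch so that
# every prompt is paired with a mismatched completion. Values are hypothetical.
batch = {
    "answer_input_ids": [[10, 11], [20, 21], [30, 31]],
    "answer_attention_mask": [[1, 1], [1, 1], [1, 1]],
}
batch["answer_input_ids"] = [batch["answer_input_ids"][-1]] + batch["answer_input_ids"][:-1]
batch["answer_attention_mask"] = [batch["answer_attention_mask"][-1]] + batch["answer_attention_mask"][:-1]
# (x_1, y_1), (x_2, y_2), (x_3, y_3) --> (x_1, y_3), (x_2, y_1), (x_3, y_2)
assert batch["answer_input_ids"] == [[30, 31], [10, 11], [20, 21]]
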
- - Args: - eval_dataset (`torch.utils.data.Dataset`, *optional*): - If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted - by the `model.forward()` method are automatically removed. It must implement `__len__`. - """ - if eval_dataset is None and self.eval_dataset is None: - raise ValueError("Trainer: evaluation requires an eval_dataset.") - eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset - - if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs: - dataloader_params = { - "batch_size": self.args.per_device_eval_batch_size, - "collate_fn": self.data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - "shuffle": False, - } - - # prepare dataloader - data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params)) - - reference_completion_logps = [] - reference_KL_logps = [] - - for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"): - reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch) - - reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp) - reference_completion_logps.append(reference_completion_logp.cpu()) - - if self.calculate_KL: - reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp) - reference_KL_logps.append(reference_KL_logp.cpu()) - - eval_dataset = eval_dataset.add_column( - name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy() - ) - if self.calculate_KL: - eval_dataset = eval_dataset.add_column( - name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy() - ) - - # Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs - if self.eval_dataset is not None: - self.eval_dataset = eval_dataset - self._precomputed_eval_ref_log_probs = True - - return super().get_eval_dataloader(eval_dataset=eval_dataset) - - def compute_reference_log_probs(self, padded_batch: dict) -> dict: - """Computes log probabilities of the reference model for a single padded batch of a KTO specific dataset.""" - with torch.no_grad(): - if self.ref_model is None: - with self.null_ref_context(): - if self.is_encoder_decoder: - completion_logits = self.model( - padded_batch["prompt_input_ids"], - attention_mask=padded_batch["prompt_attention_mask"], - decoder_input_ids=padded_batch.get("completion_decoder_input_ids"), - labels=padded_batch["completion_labels"], - ).logits - - if self.calculate_KL: - KL_logits = self.model( - padded_batch["KL_prompt_input_ids"], - attention_mask=padded_batch["KL_prompt_attention_mask"], - decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"), - labels=padded_batch["KL_completion_labels"], - ).logits - else: - completion_logits = self.model( - padded_batch["completion_input_ids"], - attention_mask=padded_batch["completion_attention_mask"], - ).logits - - if self.calculate_KL: - KL_logits = self.model( - padded_batch["KL_completion_input_ids"], - attention_mask=padded_batch["KL_completion_attention_mask"], - ).logits - else: - if self.is_encoder_decoder: - completion_logits = self.ref_model( - padded_batch["prompt_input_ids"], - attention_mask=padded_batch["prompt_attention_mask"], - decoder_input_ids=padded_batch.get("completion_decoder_input_ids"), - labels=padded_batch["completion_labels"], - ).logits - - if self.calculate_KL: - KL_logits = 
self.ref_model( - padded_batch["KL_prompt_input_ids"], - attention_mask=padded_batch["KL_prompt_attention_mask"], - decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"), - labels=padded_batch["KL_completion_labels"], - ).logits - else: - completion_logits = self.ref_model( - padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"] - ).logits - - if self.calculate_KL: - KL_logits = self.ref_model( - padded_batch["KL_completion_input_ids"], - attention_mask=padded_batch["KL_completion_attention_mask"], - ).logits - - completion_logps = self.get_batch_logps( - completion_logits, - padded_batch["completion_labels"], - average_log_prob=False, - is_encoder_decoder=self.is_encoder_decoder, - label_pad_token_id=self.label_pad_token_id, - ) - - if self.calculate_KL: - KL_logps = self.get_batch_logps( - KL_logits, - padded_batch["KL_completion_labels"], - average_log_prob=False, - is_encoder_decoder=self.is_encoder_decoder, - label_pad_token_id=self.label_pad_token_id, - ) - else: - KL_logps = None - - return completion_logps, KL_logps - - @staticmethod - def get_batch_logps( - logits: torch.FloatTensor, - labels: torch.LongTensor, - average_log_prob: bool = False, - label_pad_token_id: int = -100, - is_encoder_decoder: bool = False, - ) -> torch.FloatTensor: - """Compute the log probabilities of the given labels under the given logits. - - Args: - logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) - labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length) - average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. - - Returns: - A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. 
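
# Note: a self-contained sketch of the log-probability computation that
# `get_batch_logps` performs (and that the refactor inlines below), using random
# logits and a toy label tensor; -100 marks padded/prompt positions.
import torch

logits = torch.randn(2, 5, 10)  # (batch, seq, vocab)
labels = torch.tensor([[1, 2, 3, -100, -100],
                       [4, 5, -100, -100, -100]])
labels = labels[:, 1:].clone()   # shift: token t is predicted from position t-1
logits = logits[:, :-1, :]
loss_mask = labels != -100
labels[labels == -100] = 0       # dummy ids at masked positions, ignored via the mask
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
completion_logps = (per_token_logps * loss_mask).sum(-1)  # shape: (batch,)
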
- """ - if logits.shape[:-1] != labels.shape: - raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") - - if not is_encoder_decoder: - labels = labels[:, 1:].clone() - logits = logits[:, :-1, :] - else: - # Fixes end-dec RuntimeError - labels = labels.clone() - - loss_mask = labels != label_pad_token_id - - # dummy token; we'll ignore the losses on these tokens later - labels[labels == label_pad_token_id] = 0 - - per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) - - if average_log_prob: - return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) - else: - return (per_token_logps * loss_mask).sum(-1) - def forward( self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]] ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: - if self.calculate_KL: - KL_logps = None - KL_model_kwargs = ( - { - "input_ids": batch["KL_prompt_input_ids"], - "attention_mask": batch["KL_prompt_attention_mask"], - "labels": batch["KL_completion_labels"], - "decoder_input_ids": batch.get("KL_completion_decoder_input_ids"), - } - if self.is_encoder_decoder - else { - "input_ids": batch["KL_completion_input_ids"], - "attention_mask": batch["KL_completion_attention_mask"], - } - ) - with torch.no_grad(): - KL_logits = model( - **KL_model_kwargs, - ).logits - - KL_logps = self.get_batch_logps( - KL_logits, - batch["KL_completion_labels"], - average_log_prob=False, - is_encoder_decoder=self.is_encoder_decoder, - label_pad_token_id=self.label_pad_token_id, - ) - else: - KL_logps = None + kl_logps = None - model_kwargs = ( - { - "labels": batch["completion_labels"], - "decoder_input_ids": batch.get("completion_decoder_input_ids"), - } - if self.is_encoder_decoder - else {} - ) - if self.aux_loss_enabled: - model_kwargs["output_router_logits"] = True + with torch.no_grad(): + kl_logits = model( + input_ids=batch["kl_completion_input_ids"], attention_mask=batch["kl_completion_attention_mask"] + ).logits + + labels = batch["kl_completion_labels"][:, 1:].clone() + logits = kl_logits[:, :-1, :] + loss_mask = labels != self.label_pad_token_id + labels[labels == self.label_pad_token_id] = 0 + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + kl_logps = (per_token_logps * loss_mask).sum(-1) - outputs = model( - batch["completion_input_ids"], - attention_mask=batch["completion_attention_mask"], - **model_kwargs, - ) + outputs = model(batch["completion_input_ids"], attention_mask=batch["completion_attention_mask"]) completion_logits = outputs.logits - completion_logps = self.get_batch_logps( - completion_logits, - batch["completion_labels"], - average_log_prob=False, - is_encoder_decoder=self.is_encoder_decoder, - label_pad_token_id=self.label_pad_token_id, - ) - - if completion_logps.shape[0] != len(batch["label"]): - raise ValueError( - "There is a mismatch between the number of examples in this batch and the number of " - "examples for which an output sequence was predicted." 
- ) + labels = batch["completion_labels"][:, 1:].clone() + logits = completion_logits[:, :-1, :] + loss_mask = labels != self.label_pad_token_id + labels[labels == self.label_pad_token_id] = 0 + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + completion_logps = (per_token_logps * loss_mask).sum(-1) chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True] rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False] - chosen_logps = completion_logps[chosen_idx, ...] - rejected_logps = completion_logps[rejected_idx, ...] + chosen_logps = completion_logps[chosen_idx] + rejected_logps = completion_logps[rejected_idx] - chosen_logits = completion_logits[chosen_idx, ...] - rejected_logits = completion_logits[rejected_idx, ...] + chosen_logits = completion_logits[chosen_idx] + rejected_logits = completion_logits[rejected_idx] - if self.aux_loss_enabled: - return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps, outputs.aux_loss) - else: - return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps) + return { + "chosen_logps": chosen_logps, + "rejected_logps": rejected_logps, + "chosen_logits": chosen_logits, + "rejected_logits": rejected_logits, + "kl_logps": kl_logps, + } def kto_loss( self, - policy_chosen_logps: torch.FloatTensor, - policy_rejected_logps: torch.FloatTensor, - policy_KL_logps: torch.FloatTensor, - reference_chosen_logps: torch.FloatTensor, - reference_rejected_logps: torch.FloatTensor, - reference_KL_logps: torch.FloatTensor, + chosen_logps: torch.FloatTensor, + rejected_logps: torch.FloatTensor, + kl_logps: torch.FloatTensor, + ref_chosen_logps: torch.FloatTensor, + ref_rejected_logps: torch.FloatTensor, + ref_kl_logps: torch.FloatTensor, ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: """Compute the KTO loss for a batch of policy and reference model log probabilities. Args: - policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,) - policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,) - policy_KL_logps: Log probabilities of the policy model for the KL responses. Shape: (batch_size,) - reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,) - reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,) - reference_KL_logps: Log probabilities of the reference model for the KL responses. Shape: (batch_size,) + chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,) + rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,) + kl_logps: Log probabilities of the policy model for the kl responses. Shape: (batch_size,) + ref_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,) + ref_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,) + ref_kl_logps: Log probabilities of the reference model for the kl responses. Shape: (batch_size,) Returns: - A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, KL). 
+ A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, kl). The losses tensor contains the KTO loss for each example in the batch. The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. - The KL tensor contains the detached KL divergence estimate between the policy and reference models. + The kl tensor contains the detached kl divergence estimate between the policy and reference models. """ - if self.calculate_KL: - kl = (policy_KL_logps - reference_KL_logps).mean().detach() - kl = self.accelerator.gather(kl).mean().clamp(min=0) - else: - kl = torch.zeros(1).to(policy_chosen_logps.device) + kl = (kl_logps - ref_kl_logps).mean().detach() + kl = self.accelerator.gather(kl).mean().clamp(min=0) # Chosen losses - if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0: - chosen_logratios = policy_chosen_logps - reference_chosen_logps - - if self.loss_type == "kto": - # Eqn (7) of the KTO paper (https://huggingface.co/papers/2402.01306) - chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl)) - elif self.loss_type == "apo_zero_unpaired": - # Unpaired variant of Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266) - # Use this loss when you believe the chosen outputs are better than your model's default output - chosen_losses = 1 - F.sigmoid(self.beta * chosen_logratios) - + if chosen_logps.shape[0] != 0 or ref_chosen_logps.shape[0] != 0: + chosen_logratios = chosen_logps - ref_chosen_logps + # Eqn (7) of the KTO paper (https://huggingface.co/papers/2402.01306) + chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl)) chosen_rewards = self.beta * chosen_logratios.detach() else: @@ -1162,13 +831,9 @@ def kto_loss( chosen_rewards = torch.Tensor([]).to(self.accelerator.device) # Rejected losses - if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0: - rejected_logratios = policy_rejected_logps - reference_rejected_logps - - if self.loss_type == "kto": - rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios)) - elif self.loss_type == "apo_zero_unpaired": - rejected_losses = F.sigmoid(self.beta * rejected_logratios) + if rejected_logps.shape[0] != 0 or ref_rejected_logps.shape[0] != 0: + rejected_logratios = rejected_logps - ref_rejected_logps + rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios)) rejected_rewards = self.beta * rejected_logratios.detach() else: @@ -1192,55 +857,17 @@ def get_batch_loss_metrics( metrics = {} batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()} - forward_output = self.forward(model, batch) - ( - policy_chosen_logps, - policy_rejected_logps, - policy_chosen_logits, - policy_rejected_logits, - policy_KL_logps, - ) = forward_output[:5] - if self.aux_loss_enabled: - aux_loss = forward_output[5] - - # if reference_logps in batch use them, otherwise use the reference model - if "reference_logps" in batch: - chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True] - rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False] - - reference_chosen_logps = batch["reference_logps"][chosen_idx, ...] - reference_rejected_logps = batch["reference_logps"][rejected_idx, ...] 
- if self.calculate_KL: - reference_KL_logps = batch["reference_KL_logps"] - else: - reference_KL_logps = None - else: - with torch.no_grad(): - if self.ref_model is None: - with self.null_ref_context(): - ( - reference_chosen_logps, - reference_rejected_logps, - _, - _, - reference_KL_logps, - ) = self.forward(self.model, batch)[:5] - else: - ( - reference_chosen_logps, - reference_rejected_logps, - _, - _, - reference_KL_logps, - ) = self.forward(self.ref_model, batch)[:5] + model_output = self.forward(model, batch) + with torch.no_grad(): + ref_model_output = self.forward(self.ref_model, batch) losses, chosen_rewards, rejected_rewards, kl = self.kto_loss( - policy_chosen_logps, - policy_rejected_logps, - policy_KL_logps, - reference_chosen_logps, - reference_rejected_logps, - reference_KL_logps, + model_output["chosen_logps"], + model_output["rejected_logps"], + model_output["kl_logps"], + ref_model_output["chosen_logps"], + ref_model_output["rejected_logps"], + ref_model_output["kl_logps"], ) metrics["kl"] = kl.item() @@ -1252,19 +879,25 @@ def get_batch_loss_metrics( if all_num_chosen > 0: metrics["rewards/chosen_sum"] = self.accelerator.gather(chosen_rewards.nansum()).nansum().item() - metrics["logps/chosen_sum"] = self.accelerator.gather(policy_chosen_logps.nansum()).nansum().item() - metrics["logits/chosen_sum"] = self.accelerator.gather(policy_chosen_logits.nansum()).nansum().item() + metrics["logps/chosen_sum"] = ( + self.accelerator.gather(model_output["chosen_logps"].nansum()).nansum().item() + ) + metrics["logits/chosen_sum"] = ( + self.accelerator.gather(model_output["chosen_logits"].nansum()).nansum().item() + ) metrics["count/chosen"] = all_num_chosen if all_num_rejected > 0: metrics["rewards/rejected_sum"] = self.accelerator.gather(rejected_rewards.nansum()).nansum().item() - metrics["logps/rejected_sum"] = self.accelerator.gather(policy_rejected_logps.nansum()).nansum().item() - metrics["logits/rejected_sum"] = self.accelerator.gather(policy_rejected_logits.nansum()).nansum().item() + metrics["logps/rejected_sum"] = ( + self.accelerator.gather(model_output["rejected_logps"].nansum()).nansum().item() + ) + metrics["logits/rejected_sum"] = ( + self.accelerator.gather(model_output["rejected_logits"].nansum()).nansum().item() + ) metrics["count/rejected"] = all_num_rejected loss = losses.nanmean() - if self.aux_loss_enabled: - loss += self.aux_loss_coef * aux_loss return loss, metrics @@ -1315,13 +948,13 @@ def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) pad_token_id=self.processing_class.pad_token_id, ) - # if reference_output in batch use that otherwise use the reference model - if "reference_output" in batch: - reference_output = batch["reference_output"] + # if ref_output in batch use that otherwise use the reference model + if "ref_output" in batch: + ref_output = batch["ref_output"] else: if self.ref_model is None: with self.null_ref_context(): - reference_output = self.model.generate( + ref_output = self.model.generate( input_ids=batch["prompt_input_ids"], attention_mask=batch["prompt_attention_mask"], max_length=self.max_length, @@ -1329,7 +962,7 @@ def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) pad_token_id=self.processing_class.pad_token_id, ) else: - reference_output = self.ref_model.generate( + ref_output = self.ref_model.generate( input_ids=batch["prompt_input_ids"], attention_mask=batch["prompt_attention_mask"], max_length=self.max_length, @@ -1340,10 +973,10 @@ def 
generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id) policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True) - reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id) - reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True) + ref_output = pad_to_length(ref_output, self.max_length, self.processing_class.pad_token_id) + ref_output_decoded = self.processing_class.batch_decode(ref_output, skip_special_tokens=True) - return policy_output_decoded, reference_output_decoded + return policy_output_decoded, ref_output_decoded def prediction_step( self, From 3e67ccbd9db0c3d53b45007602923f9d8bcdef52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 19 Dec 2024 16:36:59 +0000 Subject: [PATCH 02/13] remove columns --- trl/trainer/kto_trainer.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index 0d07c81c77..ee4c6b04d4 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -173,12 +173,6 @@ def _process_tokens(example: dict[str, Any], model: "PreTrainedModel" = None, ** # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 # 3. https://github.com/LianjiaTech/BELLE/issues/337 - if not isinstance(prompt, str): - raise ValueError(f"prompt should be an str but got {type(prompt)}") - - if not isinstance(completion, str): - raise ValueError(f"completion should be an str but got {type(completion)}") - # keys of format prompt_* refers to just the prompt and answer_* refers to just the answer all_tokens = { "prompt_input_ids": example["prompt_input_ids"], @@ -509,7 +503,6 @@ def make_inputs_require_grad(module, input, output): self.beta = args.beta self.desirable_weight = args.desirable_weight self.undesirable_weight = args.undesirable_weight - self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0) # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the # input tensor associated with the key "input_ids". 
However, in KTO, the sampled data does not include the @@ -669,6 +662,22 @@ def make_inputs_require_grad(module, input, output): UserWarning, ) + train_dataset= train_dataset.remove_columns( + [ + "prompt", + "completion", + "prompt_input_ids", + "prompt_attention_mask", + "answer_input_ids", + "answer_attention_mask", + "kl_prompt", + "kl_completion", + "kl_label", + "kl_prompt_input_ids", + "kl_prompt_attention_mask", + ] + ) + super().__init__( model=model, args=args, From 3938a52643510167269e912d62406fd41a534ef5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 19 Dec 2024 16:50:16 +0000 Subject: [PATCH 03/13] prepare_dataset --- trl/trainer/kto_trainer.py | 54 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 52 insertions(+), 2 deletions(-) diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index ee4c6b04d4..1cc0f55a40 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -31,7 +31,7 @@ import transformers from accelerate import PartialState from accelerate.utils import is_deepspeed_available -from datasets import Dataset, concatenate_datasets +from datasets import Dataset, IterableDataset, concatenate_datasets from packaging import version from torch.utils.data import DataLoader, SequentialSampler from transformers import ( @@ -513,6 +513,19 @@ def make_inputs_require_grad(module, input, output): # issued. model.warnings_issued["estimate_tokens"] = True + # 4. Handle the dataset - UNCOMMENT WHEN _prepare_dataset READY + # preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False) + # if preprocess_dataset: + # train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train") + # if eval_dataset is not None: + # if isinstance(eval_dataset, dict): + # eval_dataset = { + # key: self._prepare_dataset(dataset, processing_class, args, key) + # for key, dataset in eval_dataset.items() + # } + # else: + # eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval") + # Compute that only on the main process for faster data processing. 
# see: https://github.com/huggingface/trl/pull/1255 with PartialState().local_main_process_first(): @@ -662,7 +675,7 @@ def make_inputs_require_grad(module, input, output): UserWarning, ) - train_dataset= train_dataset.remove_columns( + train_dataset = train_dataset.remove_columns( [ "prompt", "completion", @@ -712,6 +725,43 @@ def make_inputs_require_grad(module, input, output): else: self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + def _prepare_dataset( + self, + dataset: Union[Dataset, IterableDataset], + processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin], + args: KTOConfig, + dataset_name: str, + ) -> Union[Dataset, IterableDataset]: + # Build the kwargs for the `map` function + map_kwargs = {} + if isinstance(dataset, Dataset): # IterableDataset does not support num_proc + map_kwargs["num_proc"] = args.dataset_num_proc + + with PartialState().local_main_process_first(): + # Extract prompt if needed + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset" + dataset = dataset.map(maybe_extract_prompt, **map_kwargs) + + # Unpair the dataset if needed + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Unpairing {dataset_name} dataset" + dataset = maybe_unpair_preference_dataset(dataset, **map_kwargs) + + # Apply the chat template if needed + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset" + dataset = dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, **map_kwargs) + + # HERE + + # # Tokenize the dataset + # if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + # map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" + # dataset = dataset.map(lambda ex: processing_class(ex[args.dataset_text_field]), **map_kwargs) + + return dataset + def _prepare_deepspeed(self, model: PreTrainedModelWrapper): # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 deepspeed_plugin = self.accelerator.state.deepspeed_plugin From 9c8904599bb298bd836aa69272bbf77f10adc49a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 20 Dec 2024 11:25:14 +0000 Subject: [PATCH 04/13] process dataset --- trl/trainer/kto_trainer.py | 585 ++++++++++++++----------------------- 1 file changed, 223 insertions(+), 362 deletions(-) diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index 1cc0f55a40..0170a59fff 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -20,10 +20,10 @@ from collections import defaultdict from contextlib import contextmanager, nullcontext from copy import deepcopy +from dataclasses import dataclass from operator import itemgetter -from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union +from typing import Any, Callable, Literal, Optional, Union -import numpy as np import torch import torch.amp as amp import torch.nn as nn @@ -31,7 +31,7 @@ import transformers from accelerate import PartialState from accelerate.utils import is_deepspeed_available -from datasets import Dataset, IterableDataset, concatenate_datasets +from datasets import Dataset, IterableDataset from packaging import version from torch.utils.data 
import DataLoader, SequentialSampler from transformers import ( @@ -47,6 +47,7 @@ TrainingArguments, is_wandb_available, ) +from transformers.data.data_collator import DataCollatorMixin from transformers.trainer_utils import EvalLoopOutput, has_length from transformers.utils import is_peft_available @@ -54,10 +55,10 @@ from ..models import PreTrainedModelWrapper, create_reference_model from .kto_config import KTOConfig from .utils import ( - DPODataCollatorWithPadding, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, + pad, pad_to_length, peft_module_casting_to_bf16, ) @@ -73,174 +74,65 @@ if is_deepspeed_available(): import deepspeed -if TYPE_CHECKING: - from transformers import PreTrainedModel, PreTrainedTokenizer -RUNNING_NAME = "running.pt" - - -def _get_kl_dataset(batch: dict[str, list[Any]]) -> dict[str, list[Any]]: - """ - Creates mismatched pairs of prompts and completions for the KL dataset by adding a +1 offset to the order of completions. - For best results, the mismatched outputs y' used to estimate the KL term for a batch should be the same set as the matched - outputs y used to estimate the rewards in that batch, just paired with different x. +@dataclass +class DataCollatorForUnpairedPreference(DataCollatorMixin): """ - batch["answer_input_ids"] = [batch["answer_input_ids"][-1]] + batch["answer_input_ids"][:-1] - batch["answer_attention_mask"] = [batch["answer_attention_mask"][-1]] + batch["answer_attention_mask"][:-1] - return batch - - -def _tokenize( - batch: dict[str, list[Any]], - tokenizer: "PreTrainedTokenizer", -) -> dict[str, list[Any]]: - """Tokenize a batch from a KTO specific dataset.""" - prompt_tokenized = tokenizer(batch["prompt"], add_special_tokens=False) - prompt_input_ids = prompt_tokenized["input_ids"] - prompt_attention_mask = prompt_tokenized["attention_mask"] - prompt_and_completion = [prompt + completion for prompt, completion in zip(batch["prompt"], batch["completion"])] - full_tokenized = tokenizer(prompt_and_completion, add_special_tokens=False) - full_input_ids = full_tokenized["input_ids"] - full_attention_mask = full_tokenized["attention_mask"] - - answer_input_ids = [f[len(p) :] for f, p in zip(full_input_ids, prompt_input_ids)] - answer_attention_mask = [f[len(p) :] for f, p in zip(full_attention_mask, prompt_attention_mask)] - - # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]` - full_concat_input_ids = [np.concatenate([p, a]) for p, a in zip(prompt_input_ids, answer_input_ids)] - # Prepare input tokens for token by token comparison - full_input_ids = [np.array(f) for f in full_input_ids] - for full, concat in zip(full_input_ids, full_concat_input_ids): - if len(full) != len(concat): - raise ValueError( - "The elements in 'full_input_ids' and 'full_concat_input_ids' must have the same pairwise length." - ) - - # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens - # can be merged together when tokenizing prompt+answer. This could result - # on the last token from the prompt being different when tokenized on its own - # vs when done as prompt+answer. - response_token_ids_start_idx = [len(p) for p in prompt_input_ids] - - # If tokenized prompt is different than both prompt+answer, then it means the - # last token has changed due to merging. 
- for idx, (p, f, r) in enumerate(zip(prompt_input_ids, full_input_ids, response_token_ids_start_idx)): - if not np.array_equal(p, f[:r]): - response_token_ids_start_idx[idx] -= 1 - - prompt_input_ids = [f[:r] for f, r in zip(full_input_ids, response_token_ids_start_idx)] - prompt_attention_mask = [f[:r] for f, r in zip(full_attention_mask, response_token_ids_start_idx)] - - for p, m in zip(prompt_input_ids, prompt_attention_mask): - if len(p) != len(m): - raise ValueError("Prompt input ids and attention mask should have the same length.") - - answer_input_ids = [f[r:] for f, r in zip(full_input_ids, response_token_ids_start_idx)] - answer_attention_mask = [f[r:] for f, r in zip(full_attention_mask, response_token_ids_start_idx)] + Data collator used for unpaired preference data. Inputs are dynamically padded to the maximum length of a batch if + they are not all of the same length. - output = dict( - prompt_input_ids=prompt_input_ids, - prompt_attention_mask=prompt_attention_mask, - answer_input_ids=answer_input_ids, - answer_attention_mask=answer_attention_mask, - ) - - return output - - -def _process_tokens(example: dict[str, Any], model: "PreTrainedModel" = None, **kwargs) -> dict: - """Process tokens of a KTO specific dataset. - - At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation - in case the prompt + completion responses is/are too long. First - we truncate the prompt; if we're still too long, we truncate the completion. - - We also create the labels for the completion responses, which are of length equal to - the sum of the length of the prompt and the completion response, with - label_pad_token_id for the prompt tokens. + Args: + pad_token_id (`int`): + Token ID to use for padding. + return_tensors (`str`, *optional*, defaults to `"pt"`): + Type of Tensor to return. Only `"pt"` is currently supported. + + Examples: + ```python + >>> from trl import DataCollatorForUnpairedPreference + >>> collator = DataCollatorForUnpairedPreference(pad_token_id=0) + >>> examples = [ + ... {"prompt_input_ids": [1, 2, 3], "completion_input_ids": [4, 5], "label": True}, + ... {"prompt_input_ids": [7, 8], "completion_input_ids": [9, 10], "label": False} + ... ] + >>> collator(examples) + {'prompt_input_ids': tensor([[1, 2, 3], + [0, 7, 8]]), + 'prompt_attention_mask': tensor([[1, 1, 1], + [0, 1, 1]]), + 'completion_input_ids': tensor([[ 4, 5], + [ 9, 10]]), + 'completion_attention_mask': tensor([[1, 1], + [1, 1]]), + 'labels': tensor([True, False])} + } + ``` """ - prompt = example["prompt"] - completion = example["completion"] - batch = { - f"{kwargs['prefix']}prompt": prompt, - f"{kwargs['prefix']}completion": completion, - f"{kwargs['prefix']}label": example["label"], - } + pad_token_id: int + return_tensors: str = "pt" - # Check issues below for more details - # 1. https://github.com/huggingface/trl/issues/907 - # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 - # 3. 
https://github.com/LianjiaTech/BELLE/issues/337 - - # keys of format prompt_* refers to just the prompt and answer_* refers to just the answer - all_tokens = { - "prompt_input_ids": example["prompt_input_ids"], - "prompt_attention_mask": example["prompt_attention_mask"], - "answer_input_ids": example["answer_input_ids"], - "answer_attention_mask": example["answer_attention_mask"], - } + def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]: + # Convert to tensor + prompt_input_ids = [torch.tensor(example["prompt_input_ids"]) for example in examples] + prompt_attention_mask = [torch.ones_like(input_ids) for input_ids in prompt_input_ids] + completion_input_ids = [torch.tensor(example["completion_input_ids"]) for example in examples] + completion_attention_mask = [torch.ones_like(input_ids) for input_ids in completion_input_ids] + labels = torch.tensor([example["label"] for example in examples]) + if "ref_completion_logps" in examples[0]: + ref_completion_logps = torch.tensor([example["ref_completion_logps"] for example in examples]) - # calculate max length by checking if BOS/EOS is already there - max_length = kwargs["max_length"] - bos_token_id = kwargs["tokenizer"].bos_token_id - eos_token_id = kwargs["tokenizer"].eos_token_id - if len(all_tokens["prompt_input_ids"]) > 0 and bos_token_id != all_tokens["prompt_input_ids"][0]: - max_length -= 1 - if len(all_tokens["answer_input_ids"]) > 0 and eos_token_id != all_tokens["answer_input_ids"][-1]: - max_length -= 1 - - # if combined sequence is too long (> max_length - 1 for BOS token - 1 for EOS), truncate the prompt - if len(all_tokens["prompt_input_ids"]) + len(all_tokens["answer_input_ids"]) > max_length: - for k in ["prompt_input_ids", "prompt_attention_mask"]: - if kwargs["truncation_mode"] == "keep_start": - all_tokens[k] = all_tokens[k][: kwargs["max_prompt_length"]] - elif kwargs["truncation_mode"] == "keep_end": - all_tokens[k] = all_tokens[k][-kwargs["max_prompt_length"] :] - else: - raise ValueError(f"Unknown truncation mode: {kwargs['truncation_mode']}") - - # if that's still too long, truncate the response - if len(all_tokens["prompt_input_ids"]) + len(all_tokens["answer_input_ids"]) > max_length: - for k in ["answer_input_ids", "answer_attention_mask"]: - all_tokens[k] = all_tokens[k][: max_length - kwargs["max_prompt_length"]] - - # all input_ids and attention mask as is. 
We then check if we need to add BOS/EOS tokens - batch[f"{kwargs['prefix']}prompt_input_ids"] = all_tokens["prompt_input_ids"] - batch[f"{kwargs['prefix']}prompt_attention_mask"] = all_tokens["prompt_attention_mask"] - batch[f"{kwargs['prefix']}completion_input_ids"] = all_tokens["prompt_input_ids"] + all_tokens["answer_input_ids"] - batch[f"{kwargs['prefix']}completion_attention_mask"] = ( - all_tokens["prompt_attention_mask"] + all_tokens["answer_attention_mask"] - ) - - # add BOS, which affects both prompt and the full completion - if bos_token_id is not None: - if len(all_tokens["prompt_input_ids"]) == 0 or bos_token_id != all_tokens["prompt_input_ids"][0]: - batch[f"{kwargs['prefix']}prompt_input_ids"] = [bos_token_id] + batch[ - f"{kwargs['prefix']}prompt_input_ids" - ] - batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + batch[f"{kwargs['prefix']}prompt_attention_mask"] - batch[f"{kwargs['prefix']}completion_input_ids"] = [bos_token_id] + batch[ - f"{kwargs['prefix']}completion_input_ids" - ] - batch[f"{kwargs['prefix']}completion_attention_mask"] = [1] + batch[ - f"{kwargs['prefix']}completion_attention_mask" - ] - # add EOS, which affects only the full completion - if len(all_tokens["answer_input_ids"]) == 0 or eos_token_id != all_tokens["answer_input_ids"][-1]: - batch[f"{kwargs['prefix']}completion_input_ids"] = batch[f"{kwargs['prefix']}completion_input_ids"] + [ - eos_token_id - ] - batch[f"{kwargs['prefix']}completion_attention_mask"] = batch[ - f"{kwargs['prefix']}completion_attention_mask" - ] + [1] - - batch[f"{kwargs['prefix']}completion_labels"] = batch[f"{kwargs['prefix']}completion_input_ids"][:] - batch[f"{kwargs['prefix']}completion_labels"][: len(batch[f"{kwargs['prefix']}prompt_input_ids"])] = [ - kwargs["label_pad_token_id"] - ] * len(batch[f"{kwargs['prefix']}prompt_input_ids"]) - - return batch + # Pad + output = {} + output["prompt_input_ids"] = pad(prompt_input_ids, padding_value=self.pad_token_id, padding_side="left") + output["prompt_attention_mask"] = pad(prompt_attention_mask, padding_value=0, padding_side="left") + output["completion_input_ids"] = pad(completion_input_ids, padding_value=self.pad_token_id) + output["completion_attention_mask"] = pad(completion_attention_mask, padding_value=0) + output["labels"] = labels + if "ref_completion_logps" in examples[0]: + output["ref_completion_logps"] = ref_completion_logps + + return output class KTOTrainer(Trainer): @@ -463,24 +355,7 @@ def make_inputs_require_grad(module, input, output): max_completion_length = None if data_collator is None: - data_collator = DPODataCollatorWithPadding( - pad_token_id=processing_class.pad_token_id, - label_pad_token_id=args.label_pad_token_id, - is_encoder_decoder=False, - ) - - if args.remove_unused_columns: - args.remove_unused_columns = False - # warn users - warnings.warn( - "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig" - " we have set it for you, but you should do it yourself in the future.", - UserWarning, - ) - - self.use_dpo_data_collator = True - else: - self.use_dpo_data_collator = False + data_collator = DataCollatorForUnpairedPreference(pad_token_id=processing_class.pad_token_id) if args.disable_dropout: disable_dropout_in_model(model) @@ -514,183 +389,44 @@ def make_inputs_require_grad(module, input, output): model.warnings_issued["estimate_tokens"] = True # 4. 
Handle the dataset - UNCOMMENT WHEN _prepare_dataset READY - # preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False) - # if preprocess_dataset: - # train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train") - # if eval_dataset is not None: - # if isinstance(eval_dataset, dict): - # eval_dataset = { - # key: self._prepare_dataset(dataset, processing_class, args, key) - # for key, dataset in eval_dataset.items() - # } - # else: - # eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval") - - # Compute that only on the main process for faster data processing. - # see: https://github.com/huggingface/trl/pull/1255 - with PartialState().local_main_process_first(): - # Extract the prompt if needed - train_dataset = train_dataset.map( - maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset" - ) - # Unpair the dataset if needed - train_dataset = maybe_unpair_preference_dataset( - train_dataset, args.dataset_num_proc, desc="Unpairing train dataset" - ) - # Apply the chat template if needed - train_dataset = train_dataset.map( - maybe_apply_chat_template, - fn_kwargs={"tokenizer": processing_class}, - num_proc=args.dataset_num_proc, - desc="Applying chat template to train dataset", - ) - if eval_dataset is not None: - eval_dataset = eval_dataset.map( - maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset" - ) - eval_dataset = maybe_unpair_preference_dataset( - eval_dataset, args.dataset_num_proc, desc="Unpairing eval dataset" - ) - eval_dataset = eval_dataset.map( - maybe_apply_chat_template, - fn_kwargs={"tokenizer": processing_class}, - num_proc=args.dataset_num_proc, - desc="Applying chat template to eval dataset", - ) - - # Tokenize and prepare the training datasets - train_dataset = train_dataset.map( - _tokenize, - batched=True, - fn_kwargs={"tokenizer": self.processing_class}, - num_proc=args.dataset_num_proc, - desc="Tokenizing train dataset", - ) - - fn_kwargs = { - "prefix": "", - "tokenizer": self.processing_class, - "max_length": self.max_length, - "truncation_mode": self.truncation_mode, - "label_pad_token_id": self.label_pad_token_id, - "max_prompt_length": self.max_prompt_length, - "max_completion_length": self.max_completion_length, - } - - train_dataset = train_dataset.map( - _process_tokens, - fn_kwargs=fn_kwargs, - num_proc=args.dataset_num_proc, - desc="Processing tokenized train dataset", - ) - - # Tokenize and prepare the eval datasets - if eval_dataset is not None: - eval_dataset = eval_dataset.map( - _tokenize, - fn_kwargs={"tokenizer": self.processing_class}, - batched=True, - num_proc=args.dataset_num_proc, - desc="Tokenizing eval dataset", - ) - - eval_dataset = eval_dataset.map( - _process_tokens, - fn_kwargs=fn_kwargs, - num_proc=args.dataset_num_proc, - desc="Processing tokenized eval dataset", - ) + train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train") + if eval_dataset is not None: + if isinstance(eval_dataset, dict): + eval_dataset = { + key: self._prepare_dataset(dataset, processing_class, args, key) + for key, dataset in eval_dataset.items() + } + else: + eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval") - # Get KL datasets if needed - if args.per_device_train_batch_size <= 1: - raise ValueError( - "Actual (not effective) batch size must be > 1. 
KTO will not work properly because the KL term will be equivalent to the implied reward." - ) + # Calculate dataset desirability balance + num_desirable = max(sum(train_dataset["label"]), 1) + num_undesirable = max(len(train_dataset["label"]) - num_desirable, 1) # "label" is binary - # create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size - # i.e., (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_n), ..., (x_n, y_1) = (x'_1, y'_1), ..., (x'_n, y'_n) - train_kl_dataset = train_dataset.map( - _get_kl_dataset, - batched=True, - batch_size=args.per_device_train_batch_size, - num_proc=args.dataset_num_proc, - desc="Extracting KL train dataset", - ) + if num_desirable != num_undesirable: + # The lower and upper bounds come from Eq. (8) of https://huggingface.co/papers/2402.01306 + des_weight_lower_bound = num_undesirable * self.undesirable_weight / num_desirable + des_weight_upper_bound = des_weight_lower_bound * 1.33 - fn_kwargs["prefix"] = "kl_" - train_kl_dataset = train_kl_dataset.map( - _process_tokens, - fn_kwargs=fn_kwargs, - num_proc=args.dataset_num_proc, - remove_columns=[c for c in train_kl_dataset.column_names if c in train_dataset.column_names], - desc="Processing tokenized train KL dataset", - ) + und_weight_upper_bound = num_desirable * self.desirable_weight / num_undesirable + und_weight_lower_bound = und_weight_upper_bound / 1.33 - # merge the datasets - train_dataset = concatenate_datasets([train_dataset, train_kl_dataset], axis=1) - - if eval_dataset is not None: - # Get KL dataset - eval_kl_dataset = eval_dataset.map( - _get_kl_dataset, - batched=True, - batch_size=args.per_device_train_batch_size, - num_proc=args.dataset_num_proc, - desc="Extracting eval KL dataset", - ) + # Check if weights are within bounds (use exact values for logic) + des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound + und_weight_in_range = und_weight_lower_bound <= self.undesirable_weight <= und_weight_upper_bound - eval_kl_dataset = eval_kl_dataset.map( - _process_tokens, - fn_kwargs=fn_kwargs, - num_proc=args.dataset_num_proc, - remove_columns=[c for c in eval_kl_dataset.column_names if c in eval_dataset.column_names], - desc="Processing tokenized eval KL dataset", + # Display warning with rounded bounds using f-strings + if not (des_weight_in_range or und_weight_in_range): + warnings.warn( + "You have different amounts of desirable/positive and undesirable/negative examples, but the " + "weights on the desirable and undesirable losses don't seem to be in an ideal range. Based on " + f"your data, we recommend EITHER desirable_weight in [{des_weight_lower_bound:.2f}, " + f"{des_weight_upper_bound:.2f}] OR undesirable_weight in [{und_weight_lower_bound:.2f}, " + f"{und_weight_upper_bound:.2f}] (but NOT BOTH). See the documentation on how to optimally set " + "these weights.", + UserWarning, ) - # merge the datasets - eval_dataset = concatenate_datasets([eval_dataset, eval_kl_dataset], axis=1) - - # calculate dataset desirability balance - num_desirable = max(sum(train_dataset["label"]), 1) - num_undesirable = max(len(train_dataset["label"]) - num_desirable, 1) # "label" is binary - - if num_desirable != num_undesirable: - # The lower and upper bounds come from Eq. 
(8) of https://huggingface.co/papers/2402.01306 - des_weight_lower_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1, 2) - des_weight_upper_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1.33, 2) - und_weight_lower_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1.33, 2) - und_weight_upper_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1, 2) - - des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound - und_weight_in_range = und_weight_lower_bound <= self.undesirable_weight <= und_weight_upper_bound - - if not (des_weight_in_range or und_weight_in_range): - warnings.warn( - "You have different amounts of desirable/positive and undesirable/negative examples but the " - "weights on the desirable and undesirable losses don't seem to be in an ideal range. Based " - f"on your data, we recommend EITHER " - f"desirable_weight in [{des_weight_lower_bound}, {des_weight_upper_bound}] or " - f"undesirable_weight in [{und_weight_lower_bound}, {und_weight_upper_bound}] (but NOT BOTH). " - "See the documentation on how to optimally set these weights.", - UserWarning, - ) - - train_dataset = train_dataset.remove_columns( - [ - "prompt", - "completion", - "prompt_input_ids", - "prompt_attention_mask", - "answer_input_ids", - "answer_attention_mask", - "kl_prompt", - "kl_completion", - "kl_label", - "kl_prompt_input_ids", - "kl_prompt_attention_mask", - ] - ) - super().__init__( model=model, args=args, @@ -753,15 +489,80 @@ def _prepare_dataset( map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset" dataset = dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, **map_kwargs) - # HERE - - # # Tokenize the dataset - # if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` - # map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" - # dataset = dataset.map(lambda ex: processing_class(ex[args.dataset_text_field]), **map_kwargs) + # Tokenize the dataset + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" + + dataset = dataset.map( + self.tokenize_row, + remove_columns=["prompt", "completion"], + fn_kwargs={ + "processing_class": processing_class, + "max_prompt_length": args.max_prompt_length, + "max_completion_length": args.max_completion_length, + # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token]) + "add_special_tokens": False, + }, + **map_kwargs, + ) return dataset + @staticmethod + def tokenize_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens): + """ + Tokenize a row of the dataset. + + Args: + features (`dict[str, str]`): + Row of the dataset, should contain the keys `"prompt"`, `"completion"` and `"label"`. + processing_class (`PreTrainedTokenizerBase`): + Processing class used to process the data. + max_prompt_length (`int` or `None`): + Maximum length of the prompt sequence. If `None`, the prompt sequence is not truncated. + max_completion_length (`int` or `None`): + Maximum length of the completion sequences. If `None`, the completion sequences are not truncated. + add_special_tokens (`bool`): + Whether to add special tokens to the sequences. Typically used for encoder-decoder models. If `True`, + the prompt sequence will have a bos token prepended and an eos token appended. 
In any case, the + completion sequences will have an eos token appended. + + Returns: + `dict[str, list[int]]`: + Tokenized sequences with the keys `"prompt_input_ids"`, `"completion_input_ids"`. + + Example: + ```python + >>> from transformers import AutoTokenizer + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B") + >>> features = {"prompt": "The sky is", "completion": " blue.", "label": 1} + >>> KTOTrainer.tokenize_row(features, tokenizer, max_prompt_length=3, max_completion_length=3, add_special_tokens=False) + {'prompt_input_ids': [785, 12884, 374], 'completion_input_ids': [6303, 13, 151643]} + ``` + """ + tokenizer = processing_class # the processing class is a tokenizer + prompt_input_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"] + completion_input_ids = tokenizer(features["completion"], add_special_tokens=False)["input_ids"] + + # Add special tokens (typically for encoder-decoder models) + if add_special_tokens: + if tokenizer.bos_token_id is not None: + prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids + if tokenizer.eos_token_id is not None: + prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id] + completion_input_ids = completion_input_ids + [tokenizer.eos_token_id] + + # Truncate prompt and completion sequences + if max_prompt_length is not None: + prompt_input_ids = prompt_input_ids[-max_prompt_length:] + if max_completion_length is not None: + completion_input_ids = completion_input_ids[:max_completion_length] + + return { + "prompt_input_ids": prompt_input_ids, + "completion_input_ids": completion_input_ids, + } + def _prepare_deepspeed(self, model: PreTrainedModelWrapper): # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 deepspeed_plugin = self.accelerator.state.deepspeed_plugin @@ -793,6 +594,14 @@ def _prepare_deepspeed(self, model: PreTrainedModelWrapper): model.eval() return model + def _set_signature_columns_if_needed(self): + # If `self.args.remove_unused_columns` is True, non-signature columns are removed. + # By default, this method sets `self._signature_columns` to the model's expected inputs. + # In KTOTrainer, we preprocess data, so using the model's signature columns doesn't work. + # Instead, we set them to the columns expected by `DataCollatorForUnpairedPreference`, hence the override. 
+ if self._signature_columns is None: + self._signature_columns = ["prompt_input_ids", "completion_input_ids", "ref_completion_logps", "label"] + @contextmanager def null_ref_context(self): """Context manager for handling null reference model (that is, peft adapter manipulation).""" @@ -832,8 +641,8 @@ def forward( per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) completion_logps = (per_token_logps * loss_mask).sum(-1) - chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True] - rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False] + chosen_idx = torch.nonzero(batch["labels"], as_tuple=True)[0] + rejected_idx = torch.nonzero(~batch["labels"], as_tuple=True)[0] chosen_logps = completion_logps[chosen_idx] rejected_logps = completion_logps[rejected_idx] @@ -914,7 +723,59 @@ def get_batch_loss_metrics( ): """Compute the KTO loss and other metrics for the given batch of inputs for train or test.""" metrics = {} - batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()} + batch = {k: v.to(self.accelerator.device) for k, v in batch.items()} + + _batch = {} + # Concat the prompt and completion input ids, and flush left + + input_ids = torch.cat((batch["prompt_input_ids"], batch["completion_input_ids"]), dim=1) + attention_mask = torch.cat((batch["prompt_attention_mask"], batch["completion_attention_mask"]), dim=1) + prompt_labels = torch.ones_like(batch["prompt_input_ids"]) * self.label_pad_token_id + completions_labels = batch["completion_input_ids"].clone() + # replace the label pad token with -100 + completions_labels[~batch["completion_attention_mask"].bool()] = -100 + labels = torch.cat((prompt_labels, completions_labels), dim=1) + + # Flush left to reduce the memory usage + # [[0, 0, x, x, x, x], -> [[x, x, x, x], + # [0, x, x, x, 0, 0]] [x, x, x, 0]] + for i in range(attention_mask.size(0)): + first_one_idx = torch.nonzero(attention_mask[i])[0].item() + input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx) + attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx) + labels[i] = torch.roll(labels[i], shifts=-first_one_idx) + + _batch["completion_input_ids"] = input_ids + _batch["completion_attention_mask"] = attention_mask + _batch["completion_labels"] = labels + + # Roll completion to create mismatch pairs + completion_input_ids = torch.roll(batch["completion_input_ids"], shifts=1, dims=0) + completion_attention_mask = torch.roll(batch["completion_attention_mask"], shifts=1, dims=0) + + input_ids = torch.cat((batch["prompt_input_ids"], completion_input_ids), dim=1) + attention_mask = torch.cat((batch["prompt_attention_mask"], completion_attention_mask), dim=1) + # prompt_labels = torch.ones_like(batch["prompt_input_ids"]) * self.label_pad_token_id + completions_labels = completion_input_ids.clone() + # replace the label pad token with -100 + completions_labels[~batch["completion_attention_mask"].bool()] = -100 + labels = torch.cat((prompt_labels, completions_labels), dim=1) + + # Flush left to reduce the memory usage + # [[0, 0, x, x, x, x], -> [[x, x, x, x], + # [0, x, x, x, 0, 0]] [x, x, x, 0]] + for i in range(attention_mask.size(0)): + first_one_idx = torch.nonzero(attention_mask[i])[0].item() + input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx) + attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx) + labels[i] = torch.roll(labels[i], shifts=-first_one_idx) + + 
_batch["kl_completion_input_ids"] = input_ids + _batch["kl_completion_attention_mask"] = attention_mask + _batch["kl_completion_labels"] = labels + _batch["labels"] = batch["labels"] + + batch = _batch model_output = self.forward(model, batch) with torch.no_grad(): From cffd592e34c13e354c5966576b5a4b7c32370bca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 20 Dec 2024 11:26:41 +0000 Subject: [PATCH 05/13] remove test --- tests/test_kto_trainer.py | 86 --------------------------------------- 1 file changed, 86 deletions(-) diff --git a/tests/test_kto_trainer.py b/tests/test_kto_trainer.py index 5cb311d272..72a6478384 100644 --- a/tests/test_kto_trainer.py +++ b/tests/test_kto_trainer.py @@ -22,7 +22,6 @@ from transformers.testing_utils import require_peft from trl import KTOConfig, KTOTrainer -from trl.trainer.kto_trainer import _get_kl_dataset, _process_tokens, _tokenize from .testing_utils import require_no_wandb @@ -122,91 +121,6 @@ def test_kto_trainer_with_ref_model_is_model(self): train_dataset=dummy_dataset["train"], ) - def test_tokenize_and_process_tokens(self): - with tempfile.TemporaryDirectory() as tmp_dir: - training_args = KTOConfig( - output_dir=tmp_dir, - per_device_train_batch_size=2, - max_steps=3, - remove_unused_columns=False, - gradient_accumulation_steps=1, - learning_rate=9e-1, - eval_strategy="steps", - beta=0.1, - report_to="none", - ) - - dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") - - trainer = KTOTrainer( - model=self.model, - ref_model=self.ref_model, - args=training_args, - processing_class=self.tokenizer, - train_dataset=dummy_dataset["train"], - eval_dataset=dummy_dataset["test"], - ) - - train_dataset = dummy_dataset["train"] - tokenized_dataset = train_dataset.map( - _tokenize, - fn_kwargs={"tokenizer": trainer.tokenizer}, - batched=True, - batch_size=2, - ) - self.assertListEqual(tokenized_dataset["prompt"], train_dataset["prompt"]) - self.assertListEqual(tokenized_dataset["completion"], train_dataset["completion"]) - self.assertListEqual(tokenized_dataset["label"], train_dataset["label"]) - self.assertListEqual(tokenized_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091]) - self.assertListEqual(tokenized_dataset["prompt_attention_mask"][0], [1, 1, 1, 1]) - self.assertListEqual(tokenized_dataset["answer_input_ids"][0], [27261, 13]) - self.assertListEqual(tokenized_dataset["answer_attention_mask"][0], [1, 1]) - - # Test corruption of (prompt, completion) pairs for KL dataset - for batch_size in [2, 3]: - tokenized_kl_dataset = tokenized_dataset.map(_get_kl_dataset, batched=True, batch_size=batch_size) - - # Verify that the "answer_input_ids" have been modified, meaning the new "answer_input_ids" differ - # from the original ones. However, when the length of the dataset modulo batch_size equals 1, - # the last batch remains unaltered. This is a rare scenario that does not impact the training - # process, so we exclude it from testing by iterating only up to len - 1. 
- for i in range(len(tokenized_kl_dataset["answer_input_ids"]) - 1): - self.assertListEqual( - tokenized_dataset["prompt_input_ids"][i], - tokenized_kl_dataset["prompt_input_ids"][i], - ) - self.assertListEqual( - tokenized_dataset["prompt_attention_mask"][i], - tokenized_kl_dataset["prompt_attention_mask"][i], - ) - self.assertNotEqual( - tokenized_dataset["answer_input_ids"][i], - tokenized_kl_dataset["answer_input_ids"][i], - ) - - fn_kwargs = { - "prefix": "", - "is_encoder_decoder": trainer.is_encoder_decoder, - "tokenizer": trainer.tokenizer, - "max_length": trainer.max_length, - "truncation_mode": trainer.truncation_mode, - "label_pad_token_id": trainer.label_pad_token_id, - "max_prompt_length": trainer.max_prompt_length, - } - processed_dataset = tokenized_dataset.map(_process_tokens, fn_kwargs=fn_kwargs, num_proc=2) - self.assertListEqual(processed_dataset["prompt"], train_dataset["prompt"]) - self.assertListEqual(processed_dataset["completion"], train_dataset["completion"]) - self.assertListEqual(processed_dataset["label"], train_dataset["label"]) - self.assertListEqual(processed_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091]) - self.assertListEqual(processed_dataset["prompt_attention_mask"][0], [1, 1, 1, 1]) - self.assertListEqual( - processed_dataset["completion_input_ids"][0], [46518, 374, 2664, 1091, 27261, 13, 151645] - ) - self.assertListEqual(processed_dataset["completion_attention_mask"][0], [1, 1, 1, 1, 1, 1, 1]) - self.assertListEqual( - processed_dataset["completion_labels"][0], [-100, -100, -100, -100, 27261, 13, 151645] - ) - def test_kto_trainer_without_providing_ref_model(self): with tempfile.TemporaryDirectory() as tmp_dir: training_args = KTOConfig( From 971addd41700d02a0197675eabe49e1801c4ab99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 20 Dec 2024 14:11:24 +0000 Subject: [PATCH 06/13] current progress --- trl/trainer/kto_trainer.py | 98 +++++++++++++++++++++++++------------- 1 file changed, 65 insertions(+), 33 deletions(-) diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index 0170a59fff..f04a05b306 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -617,44 +617,40 @@ def null_ref_context(self): def forward( self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]] ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: - kl_logps = None - with torch.no_grad(): kl_logits = model( - input_ids=batch["kl_completion_input_ids"], attention_mask=batch["kl_completion_attention_mask"] + input_ids=batch["kl_input_ids"], attention_mask=batch["kl_attention_mask"] ).logits - labels = batch["kl_completion_labels"][:, 1:].clone() - logits = kl_logits[:, :-1, :] + labels = batch["kl_labels"][:, 1:].clone() + logits = kl_logits[:, :-1] loss_mask = labels != self.label_pad_token_id labels[labels == self.label_pad_token_id] = 0 per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) kl_logps = (per_token_logps * loss_mask).sum(-1) - outputs = model(batch["completion_input_ids"], attention_mask=batch["completion_attention_mask"]) - completion_logits = outputs.logits - - labels = batch["completion_labels"][:, 1:].clone() - logits = completion_logits[:, :-1, :] + logits = model(batch["input_ids"], attention_mask=batch["attention_mask"]).logits + labels = batch["labels"][:, 1:].clone() + logits = logits[:, :-1] loss_mask = labels != self.label_pad_token_id labels[labels == self.label_pad_token_id] = 0 
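        # torch.gather over dim=2 of the [batch, seq, vocab] log-softmax selects the
        # log-probability of each realized label token; the loss-masked sum that follows
        # is then the per-sequence log-likelihood of the completion.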
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) - completion_logps = (per_token_logps * loss_mask).sum(-1) + logps = (per_token_logps * loss_mask).sum(-1) - chosen_idx = torch.nonzero(batch["labels"], as_tuple=True)[0] - rejected_idx = torch.nonzero(~batch["labels"], as_tuple=True)[0] + chosen_idx = torch.nonzero(batch["desirable"], as_tuple=True)[0] + rejected_idx = torch.nonzero(~batch["desirable"], as_tuple=True)[0] - chosen_logps = completion_logps[chosen_idx] - rejected_logps = completion_logps[rejected_idx] + chosen_logps = logps[chosen_idx] + rejected_logps = logps[rejected_idx] - chosen_logits = completion_logits[chosen_idx] - rejected_logits = completion_logits[rejected_idx] + chosen_logits_sum = logits[chosen_idx].nansum() + rejected_logits_sum = logits[rejected_idx].nansum() return { "chosen_logps": chosen_logps, "rejected_logps": rejected_logps, - "chosen_logits": chosen_logits, - "rejected_logits": rejected_logits, + "chosen_logits_sum": chosen_logits_sum, + "rejected_logits_sum": rejected_logits_sum, "kl_logps": kl_logps, } @@ -687,12 +683,12 @@ def kto_loss( kl = self.accelerator.gather(kl).mean().clamp(min=0) # Chosen losses + assert chosen_logps.shape[0] == ref_chosen_logps.shape[0] if chosen_logps.shape[0] != 0 or ref_chosen_logps.shape[0] != 0: chosen_logratios = chosen_logps - ref_chosen_logps - # Eqn (7) of the KTO paper (https://huggingface.co/papers/2402.01306) + # Eqn (8) of the KTO paper (https://huggingface.co/papers/2402.01306) chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl)) chosen_rewards = self.beta * chosen_logratios.detach() - else: # lists can't be empty -- if they are, then accelerate.gather will hang chosen_losses = torch.Tensor([]).to(self.accelerator.device) @@ -702,7 +698,43 @@ def kto_loss( if rejected_logps.shape[0] != 0 or ref_rejected_logps.shape[0] != 0: rejected_logratios = rejected_logps - ref_rejected_logps rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios)) + rejected_rewards = self.beta * rejected_logratios.detach() + else: + # lists can't be empty -- if they are, then accelerate.gather will hang + rejected_losses = torch.Tensor([]).to(self.accelerator.device) + rejected_rewards = torch.Tensor([]).to(self.accelerator.device) + losses = torch.cat( + (self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses), + dim=0, + ) + + return losses, chosen_rewards, rejected_rewards, kl + + def meta_kto_loss(self, model, batch): + model_output = self.forward(model, batch) + with torch.no_grad(): + ref_model_output = self.forward(self.ref_model, batch) + + kl = (model_output["kl_logps"] - model_output["kl_logps"]).mean().detach() + kl = self.accelerator.gather(kl).mean().clamp(min=0) + + # Chosen losses + assert model_output["chosen_logps"].shape[0] == ref_model_output["chosen_logps"].shape[0] + if model_output["chosen_logps"].shape[0] != 0 or ref_model_output["chosen_logps"].shape[0] != 0: + chosen_logratios = model_output["chosen_logps"] - ref_model_output["chosen_logps"] + # Eqn (8) of the KTO paper (https://huggingface.co/papers/2402.01306) + chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl)) + chosen_rewards = self.beta * chosen_logratios.detach() + else: + # lists can't be empty -- if they are, then accelerate.gather will hang + chosen_losses = torch.Tensor([]).to(self.accelerator.device) + chosen_rewards = torch.Tensor([]).to(self.accelerator.device) + + # Rejected losses + if model_output["rejected_logps"].shape[0] != 0 
or ref_model_output["rejected_logps"].shape[0] != 0: + rejected_logratios = model_output["rejected_logps"] - ref_model_output["rejected_logps"] + rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios)) rejected_rewards = self.beta * rejected_logratios.detach() else: # lists can't be empty -- if they are, then accelerate.gather will hang @@ -711,11 +743,12 @@ def kto_loss( losses = torch.cat( (self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses), - 0, + dim=0, ) return losses, chosen_rewards, rejected_rewards, kl + def get_batch_loss_metrics( self, model, @@ -732,8 +765,7 @@ def get_batch_loss_metrics( attention_mask = torch.cat((batch["prompt_attention_mask"], batch["completion_attention_mask"]), dim=1) prompt_labels = torch.ones_like(batch["prompt_input_ids"]) * self.label_pad_token_id completions_labels = batch["completion_input_ids"].clone() - # replace the label pad token with -100 - completions_labels[~batch["completion_attention_mask"].bool()] = -100 + completions_labels[~batch["completion_attention_mask"].bool()] = -100 # replace the label pad token with -100 labels = torch.cat((prompt_labels, completions_labels), dim=1) # Flush left to reduce the memory usage @@ -745,9 +777,9 @@ def get_batch_loss_metrics( attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx) labels[i] = torch.roll(labels[i], shifts=-first_one_idx) - _batch["completion_input_ids"] = input_ids - _batch["completion_attention_mask"] = attention_mask - _batch["completion_labels"] = labels + _batch["input_ids"] = input_ids + _batch["attention_mask"] = attention_mask + _batch["labels"] = labels # Roll completion to create mismatch pairs completion_input_ids = torch.roll(batch["completion_input_ids"], shifts=1, dims=0) @@ -770,10 +802,10 @@ def get_batch_loss_metrics( attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx) labels[i] = torch.roll(labels[i], shifts=-first_one_idx) - _batch["kl_completion_input_ids"] = input_ids - _batch["kl_completion_attention_mask"] = attention_mask - _batch["kl_completion_labels"] = labels - _batch["labels"] = batch["labels"] + _batch["kl_input_ids"] = input_ids + _batch["kl_attention_mask"] = attention_mask + _batch["kl_labels"] = labels + _batch["desirable"] = batch["labels"] batch = _batch @@ -803,7 +835,7 @@ def get_batch_loss_metrics( self.accelerator.gather(model_output["chosen_logps"].nansum()).nansum().item() ) metrics["logits/chosen_sum"] = ( - self.accelerator.gather(model_output["chosen_logits"].nansum()).nansum().item() + self.accelerator.gather(model_output["chosen_logits_sum"]).nansum().item() ) metrics["count/chosen"] = all_num_chosen @@ -813,7 +845,7 @@ def get_batch_loss_metrics( self.accelerator.gather(model_output["rejected_logps"].nansum()).nansum().item() ) metrics["logits/rejected_sum"] = ( - self.accelerator.gather(model_output["rejected_logits"].nansum()).nansum().item() + self.accelerator.gather(model_output["rejected_logits_sum"]).nansum().item() ) metrics["count/rejected"] = all_num_rejected From 968fff22614a068d815de217fd4015c29baaeef0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 20 Dec 2024 15:06:59 +0000 Subject: [PATCH 07/13] max length --- trl/trainer/kto_trainer.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index f04a05b306..cf2887c4e8 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -777,6 +777,12 @@ def get_batch_loss_metrics( 
attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx) labels[i] = torch.roll(labels[i], shifts=-first_one_idx) + # Truncate right + if self.args.max_length is not None: + input_ids = input_ids[:, : self.args.max_length] + attention_mask = attention_mask[:, : self.args.max_length] + labels = labels[:, : self.args.max_length] + _batch["input_ids"] = input_ids _batch["attention_mask"] = attention_mask _batch["labels"] = labels @@ -802,6 +808,12 @@ def get_batch_loss_metrics( attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx) labels[i] = torch.roll(labels[i], shifts=-first_one_idx) + # Truncate right + if self.args.max_length is not None: + input_ids = input_ids[:, : self.args.max_length] + attention_mask = attention_mask[:, : self.args.max_length] + labels = labels[:, : self.args.max_length] + _batch["kl_input_ids"] = input_ids _batch["kl_attention_mask"] = attention_mask _batch["kl_labels"] = labels @@ -809,6 +821,8 @@ def get_batch_loss_metrics( batch = _batch + + model_output = self.forward(model, batch) with torch.no_grad(): ref_model_output = self.forward(self.ref_model, batch) From c8c00dda6e2f0fc7d1eaec403ad218696fd0097e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 20 Dec 2024 17:09:43 +0000 Subject: [PATCH 08/13] here we are --- trl/trainer/kto_trainer.py | 200 +++++++++++++------------------------ 1 file changed, 67 insertions(+), 133 deletions(-) diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index cf2887c4e8..7ba14f0830 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -617,42 +617,81 @@ def null_ref_context(self): def forward( self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]] ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: - with torch.no_grad(): - kl_logits = model( - input_ids=batch["kl_input_ids"], attention_mask=batch["kl_attention_mask"] - ).logits + output = {} - labels = batch["kl_labels"][:, 1:].clone() - logits = kl_logits[:, :-1] - loss_mask = labels != self.label_pad_token_id - labels[labels == self.label_pad_token_id] = 0 - per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) - kl_logps = (per_token_logps * loss_mask).sum(-1) + # Concat the prompt and completion input ids + input_ids = torch.cat((batch["prompt_input_ids"], batch["completion_input_ids"]), dim=1) + attention_mask = torch.cat((batch["prompt_attention_mask"], batch["completion_attention_mask"]), dim=1) + prompt_labels = torch.ones_like(batch["prompt_input_ids"]) * self.label_pad_token_id + completions_labels = batch["completion_input_ids"].clone() + completions_labels[~batch["completion_attention_mask"].bool()] = -100 # replace the label pad token with -100 + labels = torch.cat((prompt_labels, completions_labels), dim=1) - logits = model(batch["input_ids"], attention_mask=batch["attention_mask"]).logits - labels = batch["labels"][:, 1:].clone() + # Flush left to reduce the memory usage + # [[0, 0, x, x, x, x], -> [[x, x, x, x], + # [0, x, x, x, 0, 0]] [x, x, x, 0]] + for i in range(attention_mask.size(0)): + first_one_idx = torch.nonzero(attention_mask[i])[0].item() + input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx) + attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx) + labels[i] = torch.roll(labels[i], shifts=-first_one_idx) + + # Truncate right + if self.args.max_length is not None: + input_ids = input_ids[:, : self.args.max_length] + 
attention_mask = attention_mask[:, : self.args.max_length] + labels = labels[:, : self.args.max_length] + + logits = model(input_ids=input_ids, attention_mask=attention_mask).logits logits = logits[:, :-1] + labels = labels[:, 1:].clone() loss_mask = labels != self.label_pad_token_id - labels[labels == self.label_pad_token_id] = 0 + labels[~loss_mask] = 0 per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) logps = (per_token_logps * loss_mask).sum(-1) + output["chosen_logps"] = logps[batch["labels"]] + output["rejected_logps"] = logps[~batch["labels"]] + output["sum_chosen_logits"] = logits[batch["labels"]].nansum() + output["sum_rejected_logits"] = logits[~batch["labels"]].nansum() - chosen_idx = torch.nonzero(batch["desirable"], as_tuple=True)[0] - rejected_idx = torch.nonzero(~batch["desirable"], as_tuple=True)[0] + # Roll completion to create mismatch pairs + kl_completion_input_ids = torch.roll(batch["completion_input_ids"], shifts=1, dims=0) + kl_completion_attention_mask = torch.roll(batch["completion_attention_mask"], shifts=1, dims=0) + + # Concat the prompt and completion input ids + kl_input_ids = torch.cat((batch["prompt_input_ids"], kl_completion_input_ids), dim=1) + kl_attention_mask = torch.cat((batch["prompt_attention_mask"], kl_completion_attention_mask), dim=1) + kl_completions_labels = kl_completion_input_ids.clone() + kl_completions_labels[ + ~batch["completion_attention_mask"].bool() + ] = -100 # replace the label pad token with -100 + kl_labels = torch.cat((prompt_labels, kl_completions_labels), dim=1) + + # Flush left to reduce the memory usage + # [[0, 0, x, x, x, x], -> [[x, x, x, x], + # [0, x, x, x, 0, 0]] [x, x, x, 0]] + for i in range(kl_attention_mask.size(0)): + first_one_idx = torch.nonzero(kl_attention_mask[i])[0].item() + kl_input_ids[i] = torch.roll(kl_input_ids[i], shifts=-first_one_idx) + kl_attention_mask[i] = torch.roll(kl_attention_mask[i], shifts=-first_one_idx) + kl_labels[i] = torch.roll(kl_labels[i], shifts=-first_one_idx) - chosen_logps = logps[chosen_idx] - rejected_logps = logps[rejected_idx] + # Truncate right + if self.args.max_length is not None: + kl_input_ids = kl_input_ids[:, : self.args.max_length] + kl_attention_mask = kl_attention_mask[:, : self.args.max_length] + kl_labels = kl_labels[:, : self.args.max_length] - chosen_logits_sum = logits[chosen_idx].nansum() - rejected_logits_sum = logits[rejected_idx].nansum() + with torch.no_grad(): + kl_logits = model(input_ids=kl_input_ids, attention_mask=kl_attention_mask).logits + logits = kl_logits[:, :-1] + labels = kl_labels[:, 1:].clone() + loss_mask = labels != self.label_pad_token_id + labels[~loss_mask] = 0 + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + output["kl_logps"] = (per_token_logps * loss_mask).sum(-1) - return { - "chosen_logps": chosen_logps, - "rejected_logps": rejected_logps, - "chosen_logits_sum": chosen_logits_sum, - "rejected_logits_sum": rejected_logits_sum, - "kl_logps": kl_logps, - } + return output def kto_loss( self, @@ -711,44 +750,6 @@ def kto_loss( return losses, chosen_rewards, rejected_rewards, kl - def meta_kto_loss(self, model, batch): - model_output = self.forward(model, batch) - with torch.no_grad(): - ref_model_output = self.forward(self.ref_model, batch) - - kl = (model_output["kl_logps"] - model_output["kl_logps"]).mean().detach() - kl = self.accelerator.gather(kl).mean().clamp(min=0) - - # Chosen losses - assert model_output["chosen_logps"].shape[0] 
== ref_model_output["chosen_logps"].shape[0] - if model_output["chosen_logps"].shape[0] != 0 or ref_model_output["chosen_logps"].shape[0] != 0: - chosen_logratios = model_output["chosen_logps"] - ref_model_output["chosen_logps"] - # Eqn (8) of the KTO paper (https://huggingface.co/papers/2402.01306) - chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl)) - chosen_rewards = self.beta * chosen_logratios.detach() - else: - # lists can't be empty -- if they are, then accelerate.gather will hang - chosen_losses = torch.Tensor([]).to(self.accelerator.device) - chosen_rewards = torch.Tensor([]).to(self.accelerator.device) - - # Rejected losses - if model_output["rejected_logps"].shape[0] != 0 or ref_model_output["rejected_logps"].shape[0] != 0: - rejected_logratios = model_output["rejected_logps"] - ref_model_output["rejected_logps"] - rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios)) - rejected_rewards = self.beta * rejected_logratios.detach() - else: - # lists can't be empty -- if they are, then accelerate.gather will hang - rejected_losses = torch.Tensor([]).to(self.accelerator.device) - rejected_rewards = torch.Tensor([]).to(self.accelerator.device) - - losses = torch.cat( - (self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses), - dim=0, - ) - - return losses, chosen_rewards, rejected_rewards, kl - - def get_batch_loss_metrics( self, model, @@ -758,71 +759,6 @@ def get_batch_loss_metrics( metrics = {} batch = {k: v.to(self.accelerator.device) for k, v in batch.items()} - _batch = {} - # Concat the prompt and completion input ids, and flush left - - input_ids = torch.cat((batch["prompt_input_ids"], batch["completion_input_ids"]), dim=1) - attention_mask = torch.cat((batch["prompt_attention_mask"], batch["completion_attention_mask"]), dim=1) - prompt_labels = torch.ones_like(batch["prompt_input_ids"]) * self.label_pad_token_id - completions_labels = batch["completion_input_ids"].clone() - completions_labels[~batch["completion_attention_mask"].bool()] = -100 # replace the label pad token with -100 - labels = torch.cat((prompt_labels, completions_labels), dim=1) - - # Flush left to reduce the memory usage - # [[0, 0, x, x, x, x], -> [[x, x, x, x], - # [0, x, x, x, 0, 0]] [x, x, x, 0]] - for i in range(attention_mask.size(0)): - first_one_idx = torch.nonzero(attention_mask[i])[0].item() - input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx) - attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx) - labels[i] = torch.roll(labels[i], shifts=-first_one_idx) - - # Truncate right - if self.args.max_length is not None: - input_ids = input_ids[:, : self.args.max_length] - attention_mask = attention_mask[:, : self.args.max_length] - labels = labels[:, : self.args.max_length] - - _batch["input_ids"] = input_ids - _batch["attention_mask"] = attention_mask - _batch["labels"] = labels - - # Roll completion to create mismatch pairs - completion_input_ids = torch.roll(batch["completion_input_ids"], shifts=1, dims=0) - completion_attention_mask = torch.roll(batch["completion_attention_mask"], shifts=1, dims=0) - - input_ids = torch.cat((batch["prompt_input_ids"], completion_input_ids), dim=1) - attention_mask = torch.cat((batch["prompt_attention_mask"], completion_attention_mask), dim=1) - # prompt_labels = torch.ones_like(batch["prompt_input_ids"]) * self.label_pad_token_id - completions_labels = completion_input_ids.clone() - # replace the label pad token with -100 - 
completions_labels[~batch["completion_attention_mask"].bool()] = -100 - labels = torch.cat((prompt_labels, completions_labels), dim=1) - - # Flush left to reduce the memory usage - # [[0, 0, x, x, x, x], -> [[x, x, x, x], - # [0, x, x, x, 0, 0]] [x, x, x, 0]] - for i in range(attention_mask.size(0)): - first_one_idx = torch.nonzero(attention_mask[i])[0].item() - input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx) - attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx) - labels[i] = torch.roll(labels[i], shifts=-first_one_idx) - - # Truncate right - if self.args.max_length is not None: - input_ids = input_ids[:, : self.args.max_length] - attention_mask = attention_mask[:, : self.args.max_length] - labels = labels[:, : self.args.max_length] - - _batch["kl_input_ids"] = input_ids - _batch["kl_attention_mask"] = attention_mask - _batch["kl_labels"] = labels - _batch["desirable"] = batch["labels"] - - batch = _batch - - - model_output = self.forward(model, batch) with torch.no_grad(): ref_model_output = self.forward(self.ref_model, batch) @@ -848,9 +784,7 @@ def get_batch_loss_metrics( metrics["logps/chosen_sum"] = ( self.accelerator.gather(model_output["chosen_logps"].nansum()).nansum().item() ) - metrics["logits/chosen_sum"] = ( - self.accelerator.gather(model_output["chosen_logits_sum"]).nansum().item() - ) + metrics["logits/chosen_sum"] = self.accelerator.gather(model_output["sum_chosen_logits"]).nansum().item() metrics["count/chosen"] = all_num_chosen if all_num_rejected > 0: @@ -859,7 +793,7 @@ def get_batch_loss_metrics( self.accelerator.gather(model_output["rejected_logps"].nansum()).nansum().item() ) metrics["logits/rejected_sum"] = ( - self.accelerator.gather(model_output["rejected_logits_sum"]).nansum().item() + self.accelerator.gather(model_output["sum_rejected_logits"]).nansum().item() ) metrics["count/rejected"] = all_num_rejected From 3cdc3a8dc894b5b54b78e19cb58fbd1257d491f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 20 Dec 2024 22:04:14 +0000 Subject: [PATCH 09/13] test --- tests/test_collators.py | 46 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 tests/test_collators.py diff --git a/tests/test_collators.py b/tests/test_collators.py new file mode 100644 index 0000000000..57faa3a298 --- /dev/null +++ b/tests/test_collators.py @@ -0,0 +1,46 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
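(Context for the test below: the collator left-pads prompts and right-pads completions so each prompt stays flush against its completion. A self-contained sketch of that convention — `pad_1d` here is a hypothetical stand-in for trl's `pad` helper:)

```python
import torch

def pad_1d(seqs, padding_value, padding_side):
    # Pad a list of 1-D tensors to a common length on the requested side.
    width = max(len(s) for s in seqs)
    out = torch.full((len(seqs), width), padding_value, dtype=seqs[0].dtype)
    for i, s in enumerate(seqs):
        if padding_side == "left":
            out[i, width - len(s) :] = s
        else:
            out[i, : len(s)] = s
    return out

prompts = [torch.tensor([1, 2, 3]), torch.tensor([6, 7])]
print(pad_1d(prompts, padding_value=0, padding_side="left"))
# tensor([[1, 2, 3],
#         [0, 6, 7]])
```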
+ +import unittest + +import torch + +from trl.trainer.kto_trainer import DataCollatorForUnpairedPreference + + +class TestDataCollatorForUnpairedPreference(unittest.TestCase): + def setUp(self): + self.collator = DataCollatorForUnpairedPreference(pad_token_id=0) + + def assertTensorEqual(self, tensor1, tensor2): + self.assertTrue(torch.equal(tensor1, tensor2), f"Tensors are not equal:\n{tensor1}\n{tensor2}") + + def test_padding_behavior(self): + examples = [ + {"prompt_input_ids": [1, 2, 3], "completion_input_ids": [4, 5], "label": True}, + {"prompt_input_ids": [6, 7], "completion_input_ids": [8, 9, 10], "label": False}, + ] + output = self.collator.torch_call(examples) + + expected_prompt_input_ids = torch.tensor([[1, 2, 3], [0, 6, 7]]) + expected_prompt_attention_mask = torch.tensor([[1, 1, 1], [0, 1, 1]]) + expected_completion_input_ids = torch.tensor([[4, 5, 0], [8, 9, 10]]) + expected_completion_attention_mask = torch.tensor([[1, 1, 0], [1, 1, 1]]) + expected_labels = torch.tensor([True, False]) + + self.assertTensorEqual(output["prompt_input_ids"], expected_prompt_input_ids) + self.assertTensorEqual(output["prompt_attention_mask"], expected_prompt_attention_mask) + self.assertTensorEqual(output["completion_input_ids"], expected_completion_input_ids) + self.assertTensorEqual(output["completion_attention_mask"], expected_completion_attention_mask) + self.assertTensorEqual(output["labels"], expected_labels) From c0bc747df3022d93b8741d1d03418c8aa8fbe012 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 20 Dec 2024 23:09:43 +0000 Subject: [PATCH 10/13] log --- trl/trainer/kto_trainer.py | 113 +++++++++++-------------------------- 1 file changed, 33 insertions(+), 80 deletions(-) diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index 7ba14f0830..c1d7f2887a 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -320,6 +320,7 @@ def make_inputs_require_grad(module, input, output): self.model_adapter_name = model_adapter_name self.ref_adapter_name = ref_adapter_name + # Get the reference model if ref_model: self.ref_model = ref_model elif self.is_peft_model or False: @@ -328,48 +329,20 @@ def make_inputs_require_grad(module, input, output): else: self.ref_model = create_reference_model(model) - if processing_class is None: - raise ValueError( - "max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding" - ) - if args.max_length is None: - warnings.warn( - "When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init" - " it will be set to `512` by default, but you should do it yourself in the future.", - UserWarning, - ) - max_length = 512 - if args.max_length is not None: - max_length = args.max_length - - if args.max_prompt_length is None: - warnings.warn( - "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init" - " it will be set to `128` by default, but you should do it yourself in the future.", - UserWarning, - ) - max_prompt_length = 128 - if args.max_prompt_length is not None: - max_prompt_length = args.max_prompt_length - - max_completion_length = None - - if data_collator is None: - data_collator = DataCollatorForUnpairedPreference(pad_token_id=processing_class.pad_token_id) - + # Disable dropout if needed if args.disable_dropout: disable_dropout_in_model(model) if self.ref_model is not None: disable_dropout_in_model(self.ref_model) - self.max_length = max_length + # Define the data collator + 
self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id + if data_collator is None: + data_collator = DataCollatorForUnpairedPreference(pad_token_id=self.padding_value) + self.generate_during_eval = args.generate_during_eval self.label_pad_token_id = args.label_pad_token_id - self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id - self.max_prompt_length = max_prompt_length - self.truncation_mode = args.truncation_mode - self.max_completion_length = max_completion_length - self.processing_class = processing_class + self.max_length = args.max_length # metric self._stored_metrics = defaultdict(lambda: defaultdict(list)) @@ -388,7 +361,7 @@ def make_inputs_require_grad(module, input, output): # issued. model.warnings_issued["estimate_tokens"] = True - # 4. Handle the dataset - UNCOMMENT WHEN _prepare_dataset READY + # Dataset preparation train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train") if eval_dataset is not None: if isinstance(eval_dataset, dict): @@ -399,7 +372,7 @@ def make_inputs_require_grad(module, input, output): else: eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval") - # Calculate dataset desirability balance + # Calculate dataset desirability balance and display warning if weights are not in an ideal range num_desirable = max(sum(train_dataset["label"]), 1) num_undesirable = max(len(train_dataset["label"]) - num_desirable, 1) # "label" is binary @@ -780,22 +753,28 @@ def get_batch_loss_metrics( all_num_rejected = self.accelerator.gather(num_rejected).sum().item() if all_num_chosen > 0: - metrics["rewards/chosen_sum"] = self.accelerator.gather(chosen_rewards.nansum()).nansum().item() - metrics["logps/chosen_sum"] = ( - self.accelerator.gather(model_output["chosen_logps"].nansum()).nansum().item() + metrics["rewards/chosen"] = ( + self.accelerator.gather(chosen_rewards.nansum()).nansum().item() / all_num_chosen + ) + metrics["logps/chosen"] = ( + self.accelerator.gather(model_output["chosen_logps"].nansum()).nansum().item() / all_num_chosen + ) + metrics["logits/chosen"] = ( + self.accelerator.gather(model_output["sum_chosen_logits"]).nansum().item() / all_num_chosen ) - metrics["logits/chosen_sum"] = self.accelerator.gather(model_output["sum_chosen_logits"]).nansum().item() - metrics["count/chosen"] = all_num_chosen if all_num_rejected > 0: - metrics["rewards/rejected_sum"] = self.accelerator.gather(rejected_rewards.nansum()).nansum().item() - metrics["logps/rejected_sum"] = ( - self.accelerator.gather(model_output["rejected_logps"].nansum()).nansum().item() + metrics["rewards/rejected"] = ( + self.accelerator.gather(rejected_rewards.nansum()).nansum().item() / all_num_rejected ) - metrics["logits/rejected_sum"] = ( - self.accelerator.gather(model_output["sum_rejected_logits"]).nansum().item() + metrics["logps/rejected"] = ( + self.accelerator.gather(model_output["rejected_logps"].nansum()).nansum().item() / all_num_rejected ) - metrics["count/rejected"] = all_num_rejected + metrics["logits/rejected"] = ( + self.accelerator.gather(model_output["sum_rejected_logits"]).nansum().item() / all_num_rejected + ) + + metrics["rewards/margins"] = metrics["rewards/chosen"] - metrics["rewards/rejected"] loss = losses.nanmean() @@ -845,7 +824,7 @@ def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) attention_mask=batch["prompt_attention_mask"], max_length=self.max_length, do_sample=True, - 
pad_token_id=self.processing_class.pad_token_id, + pad_token_id=self.padding_value, ) # if ref_output in batch use that otherwise use the reference model @@ -859,7 +838,7 @@ def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) attention_mask=batch["prompt_attention_mask"], max_length=self.max_length, do_sample=True, - pad_token_id=self.processing_class.pad_token_id, + pad_token_id=self.padding_value, ) else: ref_output = self.ref_model.generate( @@ -867,13 +846,13 @@ def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) attention_mask=batch["prompt_attention_mask"], max_length=self.max_length, do_sample=True, - pad_token_id=self.processing_class.pad_token_id, + pad_token_id=self.padding_value, ) - policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id) + policy_output = pad_to_length(policy_output, self.max_length, self.padding_value) policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True) - ref_output = pad_to_length(ref_output, self.max_length, self.processing_class.pad_token_id) + ref_output = pad_to_length(ref_output, self.max_length, self.padding_value) ref_output_decoded = self.processing_class.batch_decode(ref_output, skip_special_tokens=True) return policy_output_decoded, ref_output_decoded @@ -971,37 +950,11 @@ def evaluation_loop( return initial_output def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: - """ - Log `logs` on the various objects watching training, including stored metrics. - - Args: - logs (`dict[str, float]`): - The values to log. - start_time (`float` or `None`, *optional*, defaults to `None`): - Start time of the training. - """ # logs either has 'loss' or 'eval_loss' train_eval = "train" if "loss" in logs else "eval" - # train metrics should have no prefix, eval should have 'eval_' - prefix = "eval_" if train_eval == "eval" else "" - # accumulate average metrics from sums and lengths - for split in ["chosen", "rejected"]: - if f"count/{split}" in self._stored_metrics[train_eval]: - count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item() - for metric in ["rewards", "logps", "logits"]: - logs[f"{prefix}{metric}/{split}"] = ( - torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item() - / count_sum - ) - # delete obsolete metric - del self._stored_metrics[train_eval][f"{metric}/{split}_sum"] - del self._stored_metrics[train_eval][f"count/{split}"] - # calculate reward margin - if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs: - logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"] # Add averaged stored metrics to logs for key, metrics in self._stored_metrics[train_eval].items(): - logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item() + logs[key] = sum(metrics) / len(metrics) del self._stored_metrics[train_eval] if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): From 4acea256a75fceb543d75d57ad71ffc18ef71f21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 20 Dec 2024 23:46:21 +0000 Subject: [PATCH 11/13] disable dropout --- trl/trainer/kto_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index c1d7f2887a..b856af6672 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -329,7 +329,7 @@ def 
make_inputs_require_grad(module, input, output): else: self.ref_model = create_reference_model(model) - # Disable dropout if needed + # Disable dropout in the model and reference model if args.disable_dropout: disable_dropout_in_model(model) if self.ref_model is not None: From 95722ac94f6848d56892899a6c49649a1f8821b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 20 Dec 2024 23:47:07 +0000 Subject: [PATCH 12/13] disable dropout --- trl/trainer/kto_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/kto_config.py b/trl/trainer/kto_config.py index 563d0cdbc9..e5feb2dbad 100644 --- a/trl/trainer/kto_config.py +++ b/trl/trainer/kto_config.py @@ -77,7 +77,7 @@ class KTOConfig(TrainingArguments): dataset_num_proc: (`Optional[int]`, *optional*, defaults to `None`): Number of processes to use for processing the dataset. disable_dropout (`bool`, *optional*, defaults to `True`): - Whether to disable dropout in the model. + Whether to disable dropout in the model and reference model. """ learning_rate: float = 1e-6 From a543ac40841e3cbd09a61dc520a711dc4c093ded Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 20 Dec 2024 23:58:31 +0000 Subject: [PATCH 13/13] truncation --- trl/trainer/kto_config.py | 9 +++------ trl/trainer/kto_trainer.py | 30 ++++++++++++++++++++---------- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/trl/trainer/kto_config.py b/trl/trainer/kto_config.py index e5feb2dbad..d1a7bc78c3 100644 --- a/trl/trainer/kto_config.py +++ b/trl/trainer/kto_config.py @@ -32,13 +32,11 @@ class KTOConfig(TrainingArguments): Initial learning rate for [`AdamW`] optimizer. The default value replaces that of [`~transformers.TrainingArguments`]. max_length (`Optional[int]`, *optional*, defaults to `None`): - Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want - to use the default data collator. + Maximum combined length of prompt and completion; longer sequences are truncated left. max_prompt_length (`Optional[int]`, *optional*, defaults to `None`): - Maximum length of the prompt. This argument is required if you want to use the default data collator. + Maximum length of the prompt; longer prompts are truncated based on `truncation_mode`. max_completion_length (`Optional[int]`, *optional*, defaults to `None`): - Maximum length of the completion. This argument is required if you want to use the default data collator - and your model is an encoder-decoder. + Maximum length of the completion; longer completions are truncated right. beta (`float`, *optional*, defaults to `0.1`): Parameter controlling the deviation from the reference model. Higher β means less deviation from the reference model. @@ -58,7 +56,6 @@ class KTOConfig(TrainingArguments): Padding value to use. If `None`, the padding value of the tokenizer is used. truncation_mode (`str`, *optional*, defaults to `"keep_end"`): Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`. - This argument is required if you want to use the default data collator. generate_during_eval (`bool`, *optional*, defaults to `False`): If `True`, generates and logs completions from both the model and the reference model to W&B during evaluation. 
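Before the trainer-side half of this patch, a toy illustration of the two `truncation_mode` values that `tokenize_row` now honors (hypothetical token ids; `keep_end` drops the oldest prompt tokens, `keep_start` drops the newest):

```python
prompt_input_ids = [101, 102, 103, 104, 105]  # toy ids for an over-long prompt
max_prompt_length = 3

keep_end = prompt_input_ids[-max_prompt_length:]   # [103, 104, 105]
keep_start = prompt_input_ids[:max_prompt_length]  # [101, 102, 103]
```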
diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index b856af6672..1b73a89805 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -473,6 +473,7 @@ def _prepare_dataset( "processing_class": processing_class, "max_prompt_length": args.max_prompt_length, "max_completion_length": args.max_completion_length, + "truncation_mode": args.truncation_mode, # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token]) "add_special_tokens": False, }, @@ -482,7 +483,9 @@ def _prepare_dataset( return dataset @staticmethod - def tokenize_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens): + def tokenize_row( + features, processing_class, max_prompt_length, max_completion_length, truncation_mode, add_special_tokens + ): """ Tokenize a row of the dataset. @@ -495,6 +498,8 @@ def tokenize_row(features, processing_class, max_prompt_length, max_completion_l Maximum length of the prompt sequence. If `None`, the prompt sequence is not truncated. max_completion_length (`int` or `None`): Maximum length of the completion sequences. If `None`, the completion sequences are not truncated. + truncation_mode (`str`): + Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`. add_special_tokens (`bool`): Whether to add special tokens to the sequences. Typically used for encoder-decoder models. If `True`, the prompt sequence will have a bos token prepended and an eos token appended. In any case, the @@ -527,7 +532,12 @@ def tokenize_row(features, processing_class, max_prompt_length, max_completion_l # Truncate prompt and completion sequences if max_prompt_length is not None: - prompt_input_ids = prompt_input_ids[-max_prompt_length:] + if truncation_mode == "keep_end": + prompt_input_ids = prompt_input_ids[-max_prompt_length:] + elif truncation_mode == "keep_start": + prompt_input_ids = prompt_input_ids[:max_prompt_length] + else: + raise ValueError(f"Invalid truncation mode: {truncation_mode}") if max_completion_length is not None: completion_input_ids = completion_input_ids[:max_completion_length] @@ -610,10 +620,10 @@ def forward( labels[i] = torch.roll(labels[i], shifts=-first_one_idx) # Truncate right - if self.args.max_length is not None: - input_ids = input_ids[:, : self.args.max_length] - attention_mask = attention_mask[:, : self.args.max_length] - labels = labels[:, : self.args.max_length] + if self.max_length is not None: + input_ids = input_ids[:, -self.max_length :] + attention_mask = attention_mask[:, -self.max_length :] + labels = labels[:, -self.max_length :] logits = model(input_ids=input_ids, attention_mask=attention_mask).logits logits = logits[:, :-1] @@ -650,10 +660,10 @@ def forward( kl_labels[i] = torch.roll(kl_labels[i], shifts=-first_one_idx) # Truncate right - if self.args.max_length is not None: - kl_input_ids = kl_input_ids[:, : self.args.max_length] - kl_attention_mask = kl_attention_mask[:, : self.args.max_length] - kl_labels = kl_labels[:, : self.args.max_length] + if self.max_length is not None: + kl_input_ids = kl_input_ids[:, -self.max_length :] + kl_attention_mask = kl_attention_mask[:, -self.max_length :] + kl_labels = kl_labels[:, -self.max_length :] with torch.no_grad(): kl_logits = model(input_ids=kl_input_ids, attention_mask=kl_attention_mask).logits
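As a closing illustration of the flush-left pattern this series uses in `get_batch_loss_metrics` and, finally, in `forward`: each row is rolled left by the index of its first real token so that padding collects on the right before the batch is cut to `max_length`. A minimal, self-contained sketch on a toy batch:

```python
import torch

input_ids = torch.tensor([[0, 0, 5, 6, 7, 8],
                          [0, 9, 10, 11, 0, 0]])
attention_mask = torch.tensor([[0, 0, 1, 1, 1, 1],
                               [0, 1, 1, 1, 0, 0]])

# Roll each row left by the position of its first non-padding token.
for i in range(attention_mask.size(0)):
    first_one_idx = torch.nonzero(attention_mask[i])[0].item()
    input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx)
    attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx)

print(input_ids)
# tensor([[ 5,  6,  7,  8,  0,  0],
#         [ 9, 10, 11,  0,  0,  0]])
```

Note that `torch.roll` wraps values around instead of discarding them; that is safe here only because the wrapped positions are padding and the attention mask, rolled in lockstep, keeps them masked out.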