KTO refactor #2507

Draft · wants to merge 15 commits into base: main
46 changes: 46 additions & 0 deletions tests/test_collators.py
@@ -0,0 +1,46 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from trl.trainer.kto_trainer import DataCollatorForUnpairedPreference


class TestDataCollatorForUnpairedPreference(unittest.TestCase):
def setUp(self):
self.collator = DataCollatorForUnpairedPreference(pad_token_id=0)

def assertTensorEqual(self, tensor1, tensor2):
self.assertTrue(torch.equal(tensor1, tensor2), f"Tensors are not equal:\n{tensor1}\n{tensor2}")

def test_padding_behavior(self):
examples = [
{"prompt_input_ids": [1, 2, 3], "completion_input_ids": [4, 5], "label": True},
{"prompt_input_ids": [6, 7], "completion_input_ids": [8, 9, 10], "label": False},
]
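        # The collator left-pads prompts and right-pads completions; see the expected tensors below.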
output = self.collator.torch_call(examples)

expected_prompt_input_ids = torch.tensor([[1, 2, 3], [0, 6, 7]])
expected_prompt_attention_mask = torch.tensor([[1, 1, 1], [0, 1, 1]])
expected_completion_input_ids = torch.tensor([[4, 5, 0], [8, 9, 10]])
expected_completion_attention_mask = torch.tensor([[1, 1, 0], [1, 1, 1]])
expected_labels = torch.tensor([True, False])

self.assertTensorEqual(output["prompt_input_ids"], expected_prompt_input_ids)
self.assertTensorEqual(output["prompt_attention_mask"], expected_prompt_attention_mask)
self.assertTensorEqual(output["completion_input_ids"], expected_completion_input_ids)
self.assertTensorEqual(output["completion_attention_mask"], expected_completion_attention_mask)
self.assertTensorEqual(output["labels"], expected_labels)
86 changes: 0 additions & 86 deletions tests/test_kto_trainer.py
@@ -22,7 +22,6 @@
from transformers.testing_utils import require_peft

from trl import KTOConfig, KTOTrainer
from trl.trainer.kto_trainer import _get_kl_dataset, _process_tokens, _tokenize

from .testing_utils import require_no_wandb

@@ -122,91 +121,6 @@ def test_kto_trainer_with_ref_model_is_model(self):
train_dataset=dummy_dataset["train"],
)

def test_tokenize_and_process_tokens(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = KTOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=1,
learning_rate=9e-1,
eval_strategy="steps",
beta=0.1,
report_to="none",
)

dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference")

trainer = KTOTrainer(
model=self.model,
ref_model=self.ref_model,
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
)

train_dataset = dummy_dataset["train"]
tokenized_dataset = train_dataset.map(
_tokenize,
fn_kwargs={"tokenizer": trainer.tokenizer},
batched=True,
batch_size=2,
)
self.assertListEqual(tokenized_dataset["prompt"], train_dataset["prompt"])
self.assertListEqual(tokenized_dataset["completion"], train_dataset["completion"])
self.assertListEqual(tokenized_dataset["label"], train_dataset["label"])
self.assertListEqual(tokenized_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091])
self.assertListEqual(tokenized_dataset["prompt_attention_mask"][0], [1, 1, 1, 1])
self.assertListEqual(tokenized_dataset["answer_input_ids"][0], [27261, 13])
self.assertListEqual(tokenized_dataset["answer_attention_mask"][0], [1, 1])

# Test corruption of (prompt, completion) pairs for KL dataset
for batch_size in [2, 3]:
tokenized_kl_dataset = tokenized_dataset.map(_get_kl_dataset, batched=True, batch_size=batch_size)

# Verify that the "answer_input_ids" have been modified, meaning the new "answer_input_ids" differ
# from the original ones. However, when the length of the dataset modulo batch_size equals 1,
# the last batch remains unaltered. This is a rare scenario that does not impact the training
# process, so we exclude it from testing by iterating only up to len - 1.
for i in range(len(tokenized_kl_dataset["answer_input_ids"]) - 1):
self.assertListEqual(
tokenized_dataset["prompt_input_ids"][i],
tokenized_kl_dataset["prompt_input_ids"][i],
)
self.assertListEqual(
tokenized_dataset["prompt_attention_mask"][i],
tokenized_kl_dataset["prompt_attention_mask"][i],
)
self.assertNotEqual(
tokenized_dataset["answer_input_ids"][i],
tokenized_kl_dataset["answer_input_ids"][i],
)

fn_kwargs = {
"prefix": "",
"is_encoder_decoder": trainer.is_encoder_decoder,
"tokenizer": trainer.tokenizer,
"max_length": trainer.max_length,
"truncation_mode": trainer.truncation_mode,
"label_pad_token_id": trainer.label_pad_token_id,
"max_prompt_length": trainer.max_prompt_length,
}
processed_dataset = tokenized_dataset.map(_process_tokens, fn_kwargs=fn_kwargs, num_proc=2)
self.assertListEqual(processed_dataset["prompt"], train_dataset["prompt"])
self.assertListEqual(processed_dataset["completion"], train_dataset["completion"])
self.assertListEqual(processed_dataset["label"], train_dataset["label"])
self.assertListEqual(processed_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091])
self.assertListEqual(processed_dataset["prompt_attention_mask"][0], [1, 1, 1, 1])
self.assertListEqual(
processed_dataset["completion_input_ids"][0], [46518, 374, 2664, 1091, 27261, 13, 151645]
)
self.assertListEqual(processed_dataset["completion_attention_mask"][0], [1, 1, 1, 1, 1, 1, 1])
self.assertListEqual(
processed_dataset["completion_labels"][0], [-100, -100, -100, -100, 27261, 13, 151645]
)

def test_kto_trainer_without_providing_ref_model(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = KTOConfig(
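Aside on the deleted test: its inline comment says `_get_kl_dataset` corrupted (prompt, completion) pairs batch by batch, and that a trailing batch of size 1 was left unaltered. That behavior is consistent with rotating the completions by one position within each batch. A sketch of that idea, inferred from the test rather than quoted from the removed helper:

def get_kl_batch(batch):
    # Rotate completions one position within the batch so each prompt is
    # paired with a different completion. A batch of size 1 is a no-op,
    # which matches the test's note about the trailing length-1 batch.
    batch["answer_input_ids"] = [batch["answer_input_ids"][-1]] + batch["answer_input_ids"][:-1]
    batch["answer_attention_mask"] = [batch["answer_attention_mask"][-1]] + batch["answer_attention_mask"][:-1]
    return batch

Mismatched pairs like these are what the KL reference term in the KTO loss is estimated on.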
11 changes: 4 additions & 7 deletions trl/trainer/kto_config.py
@@ -32,13 +32,11 @@ class KTOConfig(TrainingArguments):
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
[`~transformers.TrainingArguments`].
max_length (`Optional[int]`, *optional*, defaults to `None`):
- Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
- to use the default data collator.
+ Maximum combined length of prompt and completion; longer sequences are truncated left.
max_prompt_length (`Optional[int]`, *optional*, defaults to `None`):
- Maximum length of the prompt. This argument is required if you want to use the default data collator.
+ Maximum length of the prompt; longer prompts are truncated based on `truncation_mode`.
max_completion_length (`Optional[int]`, *optional*, defaults to `None`):
- Maximum length of the completion. This argument is required if you want to use the default data collator
- and your model is an encoder-decoder.
+ Maximum length of the completion; longer completions are truncated right.
beta (`float`, *optional*, defaults to `0.1`):
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
reference model.
@@ -58,7 +56,6 @@
Padding value to use. If `None`, the padding value of the tokenizer is used.
truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
- This argument is required if you want to use the default data collator.
generate_during_eval (`bool`, *optional*, defaults to `False`):
If `True`, generates and logs completions from both the model and the reference model to W&B during
evaluation.
@@ -77,7 +74,7 @@
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
Number of processes to use for processing the dataset.
disable_dropout (`bool`, *optional*, defaults to `True`):
- Whether to disable dropout in the model.
+ Whether to disable dropout in the model and reference model.
"""

learning_rate: float = 1e-6
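Taken together, the revised docstrings state the truncation contract directly. A hedged usage sketch with illustrative values (not defaults):

from trl import KTOConfig

# Illustrative values only. Per the updated docstrings: prompts longer than
# `max_prompt_length` are truncated according to `truncation_mode`,
# completions longer than `max_completion_length` are truncated on the
# right, and the combined sequence is capped at `max_length` by truncating
# on the left.
config = KTOConfig(
    output_dir="kto-output",
    max_length=512,
    max_prompt_length=128,
    max_completion_length=384,
    truncation_mode="keep_end",
    beta=0.1,
)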