Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Liger] add native liger-kernel orpo loss #2482

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
"diffusers": ["diffusers>=0.18.0"],
"judges": ["openai>=1.23.2", "llm-blender>=0.0.2"],
# liger-kernel depends on triton, which is only available on Linux https://github.com/triton-lang/triton#compatibility
"liger": ["liger-kernel>=0.4.0; sys_platform != 'win32'"],
"liger": ["liger-kernel>=0.5.1; sys_platform != 'win32'"],
"mergekit": ["mergekit>=0.0.5.1"],
"peft": ["peft>=0.8.0"],
"quantization": ["bitsandbytes"],
Expand Down
35 changes: 34 additions & 1 deletion tests/test_orpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.testing_utils import require_peft
from transformers.testing_utils import require_liger_kernel, require_peft

from trl import ORPOConfig, ORPOTrainer
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
Expand Down Expand Up @@ -148,3 +148,36 @@ def test_orpo_trainer_with_lora(self, config_name):
# check the params have changed - ignore 0 biases
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))

@require_liger_kernel
def test_orpo_trainer_with_liger(self):
"""Test ORPO trainer with Liger loss enabled."""
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = ORPOConfig(
kashif marked this conversation as resolved.
Show resolved Hide resolved
output_dir=tmp_dir,
report_to="none",
use_liger_loss=True, # Enable Liger loss
)

dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

trainer = ORPOTrainer(
model=self.model,
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))
7 changes: 7 additions & 0 deletions trl/trainer/orpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ class ORPOConfig(TrainingArguments):
string.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
Number of processes to use for processing the dataset.
use_liger_loss (`bool`, *optional*, defaults to `False`):
Whether to use Liger loss.
base_model_attribute_name (`str`, *optional*, defaults to `"model"`):
Name of the attribute in the model that contains the base model. This is used to get the base model from the
model when the model does not have a `get_decoder` method in the case when `use_liger_loss` is `True`.
"""

learning_rate: float = 1e-6
Expand All @@ -76,3 +81,5 @@ class ORPOConfig(TrainingArguments):
is_encoder_decoder: Optional[bool] = None
model_init_kwargs: Optional[dict[str, Any]] = None
dataset_num_proc: Optional[int] = None
use_liger_loss: bool = False
base_model_attribute_name: str = "model"
209 changes: 148 additions & 61 deletions trl/trainer/orpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_peft_available, is_torch_fx_proxy
from transformers.utils import is_liger_kernel_available, is_peft_available, is_torch_fx_proxy
from transformers.utils.deprecation import deprecate_kwarg

from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
Expand All @@ -71,7 +71,6 @@
if is_peft_available():
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training

kashif marked this conversation as resolved.
Show resolved Hide resolved

if is_wandb_available():
import wandb

Expand All @@ -81,6 +80,9 @@
if is_torch_xla_available():
import torch_xla.core.xla_model as xm

if is_liger_kernel_available():
from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss


class ORPOTrainer(Trainer):
r"""
Expand Down Expand Up @@ -361,6 +363,15 @@ def make_inputs_require_grad(module, input, output):
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
)

# Import Liger loss if enabled
kashif marked this conversation as resolved.
Show resolved Hide resolved
if self.args.use_liger_loss:
if not is_liger_kernel_available():
raise ValueError(
"You set `use_liger_loss=True` but the liger kernel is not available. "
"Please install liger-kernel first: `pip install liger-kernel`"
)
self.orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta)

def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
Expand Down Expand Up @@ -756,53 +767,112 @@ def concatenated_forward(
if self.aux_loss_enabled:
model_kwargs["output_router_logits"] = True

outputs = model(
concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
use_cache=False,
**model_kwargs,
)
all_logits = outputs.logits

def cross_entropy_loss(logits, labels):
if not self.is_encoder_decoder:
# Shift so that tokens < n predict n
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
logits = logits.view(-1, logits.shape[-1])
labels = labels.view(-1)
# Enable model parallelism
labels = labels.to(logits.device)
loss = loss_fct(logits, labels)
return loss

labels = concatenated_batch["concatenated_labels"].clone()
chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

all_logps = self.get_batch_logps(
all_logits,
labels,
average_log_prob=True,
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
# orpo nll target is with respect to the concatenated prompt + completionlabels
if self.is_encoder_decoder:
labels = concatenated_batch["concatenated_labels"].clone()
else:
labels = concatenated_batch["concatenated_input_ids"].clone()
attention_mask = concatenated_batch["concatenated_attention_mask"]
labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)

if self.args.use_liger_loss:
if self.is_encoder_decoder:
# 1. Get encoder outputs
encoder_outputs = model.get_encoder()(
concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
return_dict=True,
)
# 2. Get decoder outputs
outputs = model.get_decoder()(
input_ids=model_kwargs["decoder_input_ids"],
encoder_hidden_states=encoder_outputs.last_hidden_state,
use_cache=False,
)
else:
# skip the lm head and get the last hidden state
if hasattr(model, "get_decoder"):
base_model = model.get_decoder()
else:
base_model = getattr(model, self.args.base_model_attribute_name)
outputs = base_model(
concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
use_cache=False,
**model_kwargs,
)
lm_head = model.get_output_embeddings()

# return the final loss and aux_outputs tuple
loss, aux_outputs = self.orpo_loss_fn(
lm_head.weight,
outputs.last_hidden_state[:, :-1] if not self.is_encoder_decoder else outputs.last_hidden_state,
concatenated_batch["concatenated_labels"][:, 1:]
if not self.is_encoder_decoder
else concatenated_batch["concatenated_labels"],
lm_head.bias if hasattr(lm_head, "bias") else None,
nll_target=labels[:, 1:] if not self.is_encoder_decoder else labels,
)

chosen_logps = all_logps[:len_chosen]
rejected_logps = all_logps[len_chosen:]
if self.aux_loss_enabled:
loss += self.aux_loss_coef * outputs.aux_loss

if not self.is_encoder_decoder:
chosen_logits = all_logits[:len_chosen, :-1, :]
rejected_logits = all_logits[len_chosen:, :-1, :]
return loss, aux_outputs
else:
chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]
outputs = model(
concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
use_cache=False,
output_hidden_states=False,
**model_kwargs,
)
all_logits = outputs.logits

def cross_entropy_loss(logits, labels):
if not self.is_encoder_decoder:
# Shift so that tokens < n predict n
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
logits = logits.view(-1, logits.shape[-1])
labels = labels.view(-1)
# Enable model parallelism
labels = labels.to(logits.device)
loss = loss_fct(logits, labels)
return loss

chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

all_logps = self.get_batch_logps(
all_logits,
concatenated_batch["concatenated_labels"],
average_log_prob=True,
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)

if self.aux_loss_enabled:
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss)
chosen_logps = all_logps[:len_chosen]
rejected_logps = all_logps[len_chosen:]

return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)
if not self.is_encoder_decoder:
chosen_logits = all_logits[:len_chosen, :-1, :]
rejected_logits = all_logits[len_chosen:, :-1, :]
else:
chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]

if self.aux_loss_enabled:
return (
chosen_logps,
rejected_logps,
chosen_logits,
rejected_logits,
chosen_nll_loss,
outputs.aux_loss,
)

return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)

def get_batch_loss_metrics(
self,
Expand All @@ -814,21 +884,41 @@ def get_batch_loss_metrics(
metrics = {}

forward_output = self.concatenated_forward(model, batch)
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
) = forward_output[:5]
if self.aux_loss_enabled:
aux_loss = forward_output[5]
if self.args.use_liger_loss:
# full ORPO loss and aux outputs
(
loss,
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
chosen_rewards,
rejected_rewards,
log_odds_ratio,
log_odds_chosen,
),
) = forward_output
else:
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
) = forward_output[:5]
if self.aux_loss_enabled:
aux_loss = forward_output[5]

losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
policy_chosen_logps, policy_rejected_logps
)
# full ORPO loss
loss = policy_nll_loss - losses.mean()

losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
policy_chosen_logps, policy_rejected_logps
)
# full ORPO loss
loss = policy_nll_loss - losses.mean()
if self.aux_loss_enabled:
loss += self.aux_loss_coef * aux_loss

reward_accuracies = (chosen_rewards > rejected_rewards).float()

Expand All @@ -848,8 +938,6 @@ def get_batch_loss_metrics(
xm.mark_step() # needed because .item() calls
for k, v in metrics.items():
metrics[k] = v.item()
if self.aux_loss_enabled:
loss += self.aux_loss_coef * aux_loss

return loss, metrics

Expand All @@ -861,7 +949,6 @@ def compute_loss(
num_items_in_batch=None,
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()

with compute_loss_context_manager:
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")

Expand Down
Loading