Commit

Styling
michaelbenayoun committed Oct 28, 2024
1 parent a984c52 commit 48b0bb6
Showing 1 changed file with 2 additions and 50 deletions.
52 changes: 2 additions & 50 deletions optimum/neuron/trainers.py
@@ -25,7 +25,7 @@
 import warnings
 from collections import defaultdict
 from functools import wraps
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import datasets
 import numpy as np
@@ -2193,7 +2193,7 @@ def log1p(x):
         log_odds = (policy_chosen_logps - policy_rejected_logps) - (
             log1p(-torch.exp(policy_chosen_logps)) - log1p(-torch.exp(policy_rejected_logps))
         )
-
+
         sig_ratio = torch.nn.functional.sigmoid(log_odds)
         ratio = torch.log(sig_ratio)
         losses = self.beta * ratio
@@ -2202,51 +2202,3 @@ def log1p(x):
         rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
 
         return losses, chosen_rewards, rejected_rewards, torch.mean(ratio), torch.mean(log_odds)
-
-    def _get_batch_loss_metrics(
-        self,
-        model,
-        batch: Dict[str, Union[List, torch.LongTensor]],
-        train_eval: Literal["train", "eval"] = "train",
-    ):
-        metrics = {}
-
-        forward_output = self.concatenated_forward(model, batch)
-        (
-            policy_chosen_logps,
-            policy_rejected_logps,
-            policy_chosen_logits,
-            policy_rejected_logits,
-            policy_nll_loss,
-        ) = forward_output[:5]
-        if self.aux_loss_enabled:
-            aux_loss = forward_output[5]
-
-        losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
-            policy_chosen_logps, policy_rejected_logps
-        )
-        # full ORPO loss
-        loss = policy_nll_loss - losses.mean()
-
-        reward_accuracies = (chosen_rewards > rejected_rewards).float()
-
-        prefix = "eval_" if train_eval == "eval" else ""
-        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean()
-        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean()
-        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean()
-        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean()
-        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean()
-        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean()
-        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean()
-        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean()
-        metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean()
-        metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio
-        metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen
-        if is_torch_xla_available():
-            xm.mark_step()  # needed because of the .item() calls below
-            for k, v in metrics.items():
-                metrics[k] = v.item()
-        if self.aux_loss_enabled:
-            loss += getattr(model.config, "router_aux_loss_coef", 0.0) * aux_loss
-
-        return loss, metrics

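For reference, the odds_ratio_loss computation kept above can be exercised standalone. A minimal sketch, assuming per-sequence average log-probabilities as inputs (they must be strictly negative so that 1 - exp(logp) stays positive); the beta value here is illustrative, and F.logsigmoid stands in for the diff's torch.log(torch.nn.functional.sigmoid(...)) as a numerically safer equivalent:

import torch
import torch.nn.functional as F

def odds_ratio_loss(chosen_logps, rejected_logps, beta=0.1):
    # log of the odds ratio between chosen and rejected completions,
    # where odds(p) = p / (1 - p) and the logps are average per-token log-probs
    log_odds = (chosen_logps - rejected_logps) - (
        torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
    )
    # log(sigmoid(log_odds)), scaled by beta; the trainer subtracts the mean
    # of this term from the NLL loss to form the full ORPO objective
    return beta * F.logsigmoid(log_odds)

# toy per-sequence average log-probs (negative by construction)
chosen = torch.tensor([-0.8, -1.1])
rejected = torch.tensor([-1.5, -1.4])
print(odds_ratio_loss(chosen, rejected))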
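The xm.mark_step() call in the removed method reflects a standard torch_xla pattern: flush the pending lazy graph once before reading many scalars, so that each subsequent .item() is a plain device-to-host transfer rather than triggering its own graph execution. A minimal sketch, assuming torch_xla is installed and metrics holds scalar tensors on an XLA device:

import torch_xla.core.xla_model as xm

def metrics_to_python(metrics):
    xm.mark_step()  # execute the accumulated lazy graph once, up front
    # each .item() below now reads an already-materialized scalar
    return {k: v.item() for k, v in metrics.items()}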