From 48b0bb6f615d53bb1d3a3301d89abcc13869e21e Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Mon, 28 Oct 2024 18:29:20 +0100
Subject: [PATCH] Styling

---
 optimum/neuron/trainers.py | 52 ++------------------------------------
 1 file changed, 2 insertions(+), 50 deletions(-)

diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py
index 05d967436..375e824d7 100755
--- a/optimum/neuron/trainers.py
+++ b/optimum/neuron/trainers.py
@@ -25,7 +25,7 @@
 import warnings
 from collections import defaultdict
 from functools import wraps
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import datasets
 import numpy as np
@@ -2193,7 +2193,7 @@ def log1p(x):
         log_odds = (policy_chosen_logps - policy_rejected_logps) - (
             log1p(-torch.exp(policy_chosen_logps)) - log1p(-torch.exp(policy_rejected_logps))
         )
-
+
         sig_ratio = torch.nn.functional.sigmoid(log_odds)
         ratio = torch.log(sig_ratio)
         losses = self.beta * ratio
@@ -2202,51 +2202,3 @@ def log1p(x):
         rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
 
         return losses, chosen_rewards, rejected_rewards, torch.mean(ratio), torch.mean(log_odds)
-
-    def _get_batch_loss_metrics(
-        self,
-        model,
-        batch: Dict[str, Union[List, torch.LongTensor]],
-        train_eval: Literal["train", "eval"] = "train",
-    ):
-        metrics = {}
-
-        forward_output = self.concatenated_forward(model, batch)
-        (
-            policy_chosen_logps,
-            policy_rejected_logps,
-            policy_chosen_logits,
-            policy_rejected_logits,
-            policy_nll_loss,
-        ) = forward_output[:5]
-        if self.aux_loss_enabled:
-            aux_loss = forward_output[5]
-
-        losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
-            policy_chosen_logps, policy_rejected_logps
-        )
-        # full ORPO loss
-        loss = policy_nll_loss - losses.mean()
-
-        reward_accuracies = (chosen_rewards > rejected_rewards).float()
-
-        prefix = "eval_" if train_eval == "eval" else ""
-        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean()
-        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean()
-        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean()
-        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean()
-        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean()
-        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean()
-        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean()
-        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean()
-        metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean()
-        metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio
-        metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen
-        if is_torch_xla_available():
-            xm.mark_step()  # needed because .item() calls
-        for k, v in metrics.items():
-            metrics[k] = v.item()
-        if self.aux_loss_enabled:
-            loss += getattr(model.config, "router_aux_loss_coef", 0.0) * aux_loss
-
-        return loss, metrics