Skip to content

Commit

Permalink
feat: add asian language support to CorpusLevelTranslationMetric (#479)
Browse files Browse the repository at this point in the history
* feat: add asian language support to CorpusLevelTranslationMetric

* fix: ci
  • Loading branch information
ryan-minato authored Jan 20, 2025
1 parent fee2ec3 commit be18ae5
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions src/lighteval/metrics/metrics_corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"""
import logging
import math
from typing import Literal

import numpy as np
import sacrebleu
Expand Down Expand Up @@ -89,33 +90,38 @@ def compute(self, items: list[LogprobCorpusMetricInput]):


class CorpusLevelTranslationMetric:
    def __init__(self, metric_type: str, lang: Literal["zh", "ja", "ko", ""] = ""):
        """Stores the relevant parameters for a corpus level translation metric.

        The sacrebleu metric object itself is not created here; it is built on
        demand by `get_metric`, so an unknown `metric_type` is only reported
        when the metric is first requested.

        Args:
            metric_type (str): Can be any of bleu, chrf, or ter depending on the metric to use.
            lang (str): Target language code, one of zh, ja, ko, or an empty string
                (the default) when no language-specific handling is requested.
                It is forwarded to sacrebleu as `trg_lang` for bleu, switches on
                `asian_support` for ter when non-empty, and is ignored for chrf.
        """
        self.metric_type = metric_type
        self.lang = lang

    def get_metric(self):
        """Builds the sacrebleu metric object matching `metric_type` and `lang`.

        Returns:
            A new `sacrebleu.BLEU`, `sacrebleu.CHRF` or `sacrebleu.TER` instance.

        Raises:
            ValueError: If `metric_type` is not one of bleu, chrf, or ter.
        """
        if self.metric_type == "bleu":
            return sacrebleu.BLEU(trg_lang=self.lang)
        if self.metric_type == "chrf":
            return sacrebleu.CHRF()
        if self.metric_type == "ter":
            # NOTE(review): sacrebleu's TER tokenizer appears to apply `asian_support`
            # only when `normalized` and/or `no_punct` is enabled - confirm whether
            # one of those should also be set for zh/ja/ko.
            return sacrebleu.TER(asian_support=self.lang != "")
        raise ValueError(f"Unknown corpus level translation metric type : {self.metric_type}")

    def compute(self, items: list[GenerativeCorpusMetricInput]) -> float:
        """Computes the metric score over all the corpus generated items, by using the sacrebleu implementation.

        Args:
            items (list[GenerativeCorpusMetricInput]): Corpus items; each one provides
                `golds` (its references) and `preds` (its predictions).

        Returns:
            float: The `score` attribute of the sacrebleu corpus-level result.
        """
        metric = self.get_metric()
        # `metric` is an instance, and instances do not expose `__name__`;
        # read the name from its class so the log call below cannot raise.
        metric_name = type(metric).__name__
        golds = [i.golds for i in items]
        preds = []
        for i in items:
            pred = as_list(i.preds)
            if len(pred) > 1:
                logger.info(
                    f"Multiple predictions present, keeping only the first prediction (when computing sacrebleu.{metric_name})."
                )
            # Only the first prediction of each item is scored.
            preds.append(pred[0])
        return float(metric.corpus_score(hypotheses=preds, references=golds).score)


class CorpusLevelPerplexityMetric:
Expand Down

0 comments on commit be18ae5

Please sign in to comment.