class CorpusLevelTranslationMetric:
    def __init__(self, metric_type: str, lang: Literal["zh", "ja", "ko", ""] = ""):
        """Stores the relevant parameters for a corpus level translation metric.

        Args:
            metric_type (str): Can be any of bleu, chrf, or ter depending on the metric to use.
            lang (str): Target-language code ("zh", "ja" or "ko") enabling Asian-language
                tokenization in sacrebleu; "" (the default) keeps sacrebleu's defaults.
        """
        self.metric_type = metric_type
        self.lang = lang

    def get_metric(self):
        """Instantiates the sacrebleu metric object matching ``metric_type``.

        A fresh instance is built on every call so that ``compute`` never shares
        tokenizer/cache state across evaluations.

        Raises:
            ValueError: If ``metric_type`` is not one of "bleu", "chrf" or "ter".
        """
        if self.metric_type == "bleu":
            # trg_lang selects the language-appropriate tokenizer (e.g. "zh" -> Chinese).
            return sacrebleu.BLEU(trg_lang=self.lang)
        elif self.metric_type == "chrf":
            # chrF is character-based, so no language-specific option is needed.
            return sacrebleu.CHRF()
        elif self.metric_type == "ter":
            # asian_support turns on CJK-aware tokenization; any non-empty lang enables it.
            return sacrebleu.TER(asian_support=bool(self.lang))
        else:
            raise ValueError(f"Unknown corpus level translation metric type : {self.metric_type}")

    def compute(self, items: list[GenerativeCorpusMetricInput]) -> float:
        """Computes the metric score over all the corpus generated items, by using the sacrebleu implementation."""
        metric = self.get_metric()
        golds = [i.golds for i in items]
        preds = []
        for i in items:
            pred = as_list(i.preds)
            if len(pred) > 1:
                logger.info(
                    # Bug fix: `metric` is now an instance, which has no `__name__`;
                    # use its class name ("BLEU"/"CHRF"/"TER") to avoid AttributeError.
                    f"Multiple predictions present, keeping only the first prediction (when computing sacrebleu.{type(metric).__name__})."
                )
            preds.append(pred[0])
        return float(metric.corpus_score(hypotheses=preds, references=golds).score)