Skip to content

Commit

Permalink
Merge branch 'add_swiss_legal_evals' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
JoelNiklaus committed Feb 1, 2025
2 parents 76e867a + e7f9a09 commit 3fb93c9
Show file tree
Hide file tree
Showing 12 changed files with 2,218 additions and 83 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

---

**Documentation**: <a href="https://github.com/huggingface/lighteval/wiki" target="_blank">Lighteval's Wiki</a>
**Documentation**: <a href="https://huggingface.co/docs/lighteval/index" target="_blank">Lighteval's Wiki</a>

---

Expand Down
205 changes: 137 additions & 68 deletions community_tasks/swiss_legal_evals.py

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ tensorboardX = ["tensorboardX"]
vllm = ["vllm", "ray", "more_itertools"]
quality = ["ruff==v0.2.2","pre-commit"]
tests = ["pytest==7.4.0"]
dev = ["lighteval[accelerate,quality,tests,multilingual]"]
dev = ["lighteval[accelerate,quality,tests,multilingual,math]"]
docs = ["hf-doc-builder", "watchdog"]
extended_tasks = [
"langdetect", # ifeval
Expand All @@ -111,6 +111,7 @@ multilingual = [
"jieba", # for chinese tokenizer
"pyvi", # for vietnamese tokenizer
]
math = ["latex2sympy2_extended>=0.9.0"]

[project.urls]
Homepage = "https://github.com/huggingface/lighteval"
Expand Down
108 changes: 107 additions & 1 deletion src/lighteval/metrics/dynamic_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import Callable, Literal
import logging
from typing import Callable, Literal, Sequence

import numpy as np

Expand All @@ -37,8 +38,22 @@
LogProbTokenNorm,
get_multilingual_normalizer,
)
from lighteval.metrics.utils.extractive_match_utils import ( # noqa: F401
ExprExtractionConfig,
ExtractionTarget,
IndicesExtractionConfig,
LatexExtractionConfig,
extract_target_from_pred,
get_extraction_regexes,
)
from lighteval.metrics.utils.math_comparison import compare_gold_target
from lighteval.metrics.utils.metric_utils import MetricCategory, MetricUseCase, SampleLevelMetric
from lighteval.tasks.requests import Doc
from lighteval.utils.language import Language
from lighteval.utils.timeout import timeout


logger = logging.getLogger(__name__)


def loglikelihood_acc_metric(normalization: LogProbNormalization | None = None) -> SampleLevelMetric:
Expand Down Expand Up @@ -168,3 +183,94 @@ def multilingual_quasi_exact_match_metric(
corpus_level_fn=np.mean,
higher_is_better=True,
)


def multilingual_extractive_match_metric(
language: Language = Language.ENGLISH,
gold_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),),
pred_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),),
aggregation_function: Callable[[list[float]], float] = max,
fallback_mode: Literal["no_fallback", "first_match"] = "first_match",
precision: int = 6,
) -> SampleLevelMetric:
"""Creates a language-aware extractive match metric that extracts answers from the model's output.
Known issues:
- If the task is to simplify an expression, the metric might overestimate the accuracy. This is because if the model doesn't output any anchor for the extraction (e.g final answer is..),
it's possible that the the extracted prediction will be the expression to simplify. Because we do simplifications ourselves, it can thus happen that sympy will correctly simplify the expression,
thus it will match gold, despite model not doing anything. PRs to fix this are welcome.
- There is currently no StringExtractionConfig, so if the gold is \boxed{\text{Friday}} and model outputs Friday it will not match, because nothing will be extracted.
Args:
language: Language
The language of the samples.
gold_extraction_target: Sequence[ExtractionTarget]
Extraction targets to use for gold answers. Defaults to extracting simple math expressions.
pred_extraction_target: Sequence[ExtractionTarget]
Extraction targets to use for predictions. Defaults to extracting simple math expressions.
aggregation_function: Callable[[list[float]], float]
Function to aggregate scores when multiple golds/predictions are present. Defaults to max.
fallback_mode: Literal["no_fallback", "first_match"]
How to perform extraction. Defaults to "first_match".
- "no_fallback": Only use first successfully parsed matches
- "first_match": Use the first successfully parsed match + first match irregardless the parsing success
precision: int
Number of decimal places to use when comparing numerical values. Defaults to 6.
Returns:
A sample level metric that extracts and compares mathematical expressions.
"""

@timeout(2)
def add_to_specifics_with_timeout(
formatted_doc: Doc, extracted_predictions: list[list[str]], extracted_golds: list[list[str]]
) -> None:
if formatted_doc.specific is None:
formatted_doc.specific = {}

formatted_doc.specific["extracted_predictions"] = [
str(pred) for preds in extracted_predictions for pred in preds
]
formatted_doc.specific["extracted_golds"] = [str(gold) for golds in extracted_golds for gold in golds]

def sample_level_fn(golds: list[str], predictions: list[str], formatted_doc: Doc) -> float:
gold_extraction_regexes = get_extraction_regexes(formatted_doc, gold_extraction_target, language)
pred_extraction_regexes = get_extraction_regexes(formatted_doc, pred_extraction_target, language)

extracted_predictions = [
extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode) for pred in predictions
]
extracted_golds = [extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode) for gold in golds]

# Assert on empty gold and warn on empty pred
if any(len(g) == 0 for g in extracted_golds):
raise ValueError(f"No gold targets found for at least one gold. Gold: {golds}, Pred: {predictions}")

if all(len(p) == 0 for p in extracted_predictions):
logger.warning(
f"We did not manage to extract a prediction in the correct format. Gold: {golds}, Pred: {predictions}"
)

# We have to use timeout because the sypmy to str conversion can be very slow
try:
add_to_specifics_with_timeout(formatted_doc, extracted_predictions, extracted_golds)
except: # noqa: E722
logger.warning("Timeout when adding extracted predictions and golds to specific")

return aggregation_function(
[
(1.0 if any(compare_gold_target(gold, pred, precision) for gold in extracted_golds) else 0.0)
for pred in extracted_predictions
]
)

return SampleLevelMetric(
metric_name="extractive_match",
sample_level_fn=sample_level_fn,
category=MetricCategory.GENERATIVE,
use_case=MetricUseCase.ACCURACY,
corpus_level_fn=np.mean,
higher_is_better=True,
)
26 changes: 16 additions & 10 deletions src/lighteval/metrics/metrics_corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"""
import logging
import math
from typing import Literal

import numpy as np
import sacrebleu
Expand Down Expand Up @@ -89,33 +90,38 @@ def compute(self, items: list[LogprobCorpusMetricInput]):


class CorpusLevelTranslationMetric:
def __init__(self, metric_type: str):
def __init__(self, metric_type: str, lang: Literal["zh", "ja", "ko", ""] = ""):
"""Stores the relevant parameters for a corpus level translation metric.
Args:
metric_type (str): Can be any of bleu, chrf, or ter depending on the metric to use.
"""
if metric_type == "bleu":
self.metric = sacrebleu.corpus_bleu
elif metric_type == "chrf":
self.metric = sacrebleu.corpus_chrf
elif metric_type == "ter":
self.metric = sacrebleu.corpus_ter
self.metric_type = metric_type
self.lang = lang

def get_metric(self):
if self.metric_type == "bleu":
return sacrebleu.BLEU(trg_lang=self.lang)
elif self.metric_type == "chrf":
return sacrebleu.CHRF()
elif self.metric_type == "ter":
return sacrebleu.TER(asian_support=True if self.lang != "" else False)
else:
raise ValueError(f"Unknown corpus level translation metric type : {metric_type}")
raise ValueError(f"Unknown corpus level translation metric type : {self.metric_type}")

def compute(self, items: list[GenerativeCorpusMetricInput]) -> float:
"""Computes the metric score over all the corpus generated items, by using the sacrebleu implementation."""
metric = self.get_metric()
golds = [i.golds for i in items]
preds = []
for i in items:
pred = as_list(i.preds)
if len(pred) > 1:
logger.info(
f"Multiple predictions present, keeping only the first prediction (when computing sacrebleu.{self.metric.__name__})."
f"Multiple predictions present, keeping only the first prediction (when computing sacrebleu.{metric.__name__})."
)
preds.append(pred[0])
return float(self.metric(hypotheses=preds, references=golds).score)
return float(metric.corpus_score(hypotheses=preds, references=golds).score)


class CorpusLevelPerplexityMetric:
Expand Down
Loading

0 comments on commit 3fb93c9

Please sign in to comment.