Commit

Merge branch 'huggingface:main' into add_swiss_legal_evals
rolshoven authored Jan 20, 2025
2 parents cb6bfb4 + be18ae5 commit 306ee76
Showing 14 changed files with 2,124 additions and 32 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -21,7 +21,7 @@

---

**Documentation**: <a href="https://github.com/huggingface/lighteval/wiki" target="_blank">Lighteval's Wiki</a>
**Documentation**: <a href="https://huggingface.co/docs/lighteval/index" target="_blank">Lighteval's Wiki</a>

---

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -95,7 +95,7 @@ tensorboardX = ["tensorboardX"]
vllm = ["vllm", "ray", "more_itertools"]
quality = ["ruff==v0.2.2","pre-commit"]
tests = ["pytest==7.4.0"]
dev = ["lighteval[accelerate,quality,tests,multilingual]"]
dev = ["lighteval[accelerate,quality,tests,multilingual,math]"]
docs = ["hf-doc-builder", "watchdog"]
extended_tasks = [
"langdetect", # ifeval
@@ -111,6 +111,7 @@ multilingual = [
"jieba", # for chinese tokenizer
"pyvi", # for vietnamese tokenizer
]
math = ["latex2sympy2_extended>=0.9.0"]

[project.urls]
Homepage = "https://github.com/huggingface/lighteval"
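
For context, the new `math` extras group provides `latex2sympy2_extended` (presumably backing the math-comparison utilities used by the extractive match metric further down) and is now also pulled in by the `dev` extra. A minimal sketch, assuming only the standard library, of how downstream code might guard that optional dependency; the helper name is hypothetical and not part of lighteval's API:

```python
import importlib.util


def latex2sympy2_extended_available() -> bool:
    # True once `pip install "lighteval[math]"` (or the dev extra) has pulled in the package.
    return importlib.util.find_spec("latex2sympy2_extended") is not None


if not latex2sympy2_extended_available():
    raise ImportError('Missing optional dependency; install with: pip install "lighteval[math]"')
```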
11 changes: 9 additions & 2 deletions src/lighteval/data.py
@@ -25,8 +25,15 @@
from typing import Iterator, Tuple

import torch
from packaging import version
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler, T_co


if version.parse(torch.__version__) >= version.parse("2.5.0"):
from torch.utils.data.distributed import DistributedSampler, _T_co
else:
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.distributed import T_co as _T_co

from lighteval.tasks.requests import (
GreedyUntilRequest,
@@ -318,7 +325,7 @@ class GenDistributedSampler(DistributedSampler):
as our samples are sorted by length.
"""

def __iter__(self) -> Iterator[T_co]:
def __iter__(self) -> Iterator[_T_co]:
if self.shuffle:
# deterministically shuffle based on epoch and seed
g = torch.Generator()
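
A quick illustration of the version gate above: `packaging.version.parse` gives proper semantic ordering (unlike plain string comparison), so newer torch releases and pre-releases are classified correctly. A self-contained sketch:

```python
from packaging import version

# Semantic ordering: 2.10 is newer than 2.5, which lexicographic comparison gets wrong.
assert version.parse("2.10.0") >= version.parse("2.5.0")
assert not ("2.10.0" >= "2.5.0")  # string comparison: "1" < "5"

# Pre-releases of 2.5.0 sort before the final release, so they keep the old import path.
assert version.parse("2.5.0.dev20240901") < version.parse("2.5.0")

# Local version suffixes (e.g. CUDA builds) are handled as well.
assert version.parse("2.5.1+cu121") >= version.parse("2.5.0")
```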
108 changes: 107 additions & 1 deletion src/lighteval/metrics/dynamic_metrics.py
@@ -20,7 +20,8 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import Callable, Literal
import logging
from typing import Callable, Literal, Sequence

import numpy as np

@@ -37,8 +38,22 @@
LogProbTokenNorm,
get_multilingual_normalizer,
)
from lighteval.metrics.utils.extractive_match_utils import ( # noqa: F401
ExprExtractionConfig,
ExtractionTarget,
IndicesExtractionConfig,
LatexExtractionConfig,
extract_target_from_pred,
get_extraction_regexes,
)
from lighteval.metrics.utils.math_comparison import compare_gold_target
from lighteval.metrics.utils.metric_utils import MetricCategory, MetricUseCase, SampleLevelMetric
from lighteval.tasks.requests import Doc
from lighteval.utils.language import Language
from lighteval.utils.timeout import timeout


logger = logging.getLogger(__name__)


def loglikelihood_acc_metric(normalization: LogProbNormalization | None = None) -> SampleLevelMetric:
@@ -168,3 +183,94 @@ def multilingual_quasi_exact_match_metric(
corpus_level_fn=np.mean,
higher_is_better=True,
)


def multilingual_extractive_match_metric(
language: Language = Language.ENGLISH,
gold_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),),
pred_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),),
aggregation_function: Callable[[list[float]], float] = max,
fallback_mode: Literal["no_fallback", "first_match"] = "first_match",
precision: int = 6,
) -> SampleLevelMetric:
"""Creates a language-aware extractive match metric that extracts answers from the model's output.
Known issues:
- If the task is to simplify an expression, the metric might overestimate accuracy: when the model outputs no anchor for extraction (e.g. "final answer is ..."),
the extracted prediction may simply be the expression to simplify. Since we run simplification ourselves, sympy can then correctly simplify the expression,
so it will match the gold even though the model did nothing. PRs to fix this are welcome.
- There is currently no StringExtractionConfig, so if the gold is \boxed{\text{Friday}} and the model outputs Friday, nothing will be extracted and it will not match.
Args:
language: Language
The language of the samples.
gold_extraction_target: Sequence[ExtractionTarget]
Extraction targets to use for gold answers. Defaults to extracting simple math expressions.
pred_extraction_target: Sequence[ExtractionTarget]
Extraction targets to use for predictions. Defaults to extracting simple math expressions.
aggregation_function: Callable[[list[float]], float]
Function to aggregate scores when multiple golds/predictions are present. Defaults to max.
fallback_mode: Literal["no_fallback", "first_match"]
How to perform extraction. Defaults to "first_match".
- "no_fallback": Only use first successfully parsed matches
- "first_match": Use the first successfully parsed match + first match irregardless the parsing success
precision: int
Number of decimal places to use when comparing numerical values. Defaults to 6.
Returns:
A sample level metric that extracts and compares mathematical expressions.
"""

@timeout(2)
def add_to_specifics_with_timeout(
formatted_doc: Doc, extracted_predictions: list[list[str]], extracted_golds: list[list[str]]
) -> None:
if formatted_doc.specific is None:
formatted_doc.specific = {}

formatted_doc.specific["extracted_predictions"] = [
str(pred) for preds in extracted_predictions for pred in preds
]
formatted_doc.specific["extracted_golds"] = [str(gold) for golds in extracted_golds for gold in golds]

def sample_level_fn(golds: list[str], predictions: list[str], formatted_doc: Doc) -> float:
gold_extraction_regexes = get_extraction_regexes(formatted_doc, gold_extraction_target, language)
pred_extraction_regexes = get_extraction_regexes(formatted_doc, pred_extraction_target, language)

extracted_predictions = [
extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode) for pred in predictions
]
extracted_golds = [extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode) for gold in golds]

# Raise on empty gold and warn on empty pred
if any(len(g) == 0 for g in extracted_golds):
raise ValueError(f"No gold targets found for at least one gold. Gold: {golds}, Pred: {predictions}")

if all(len(p) == 0 for p in extracted_predictions):
logger.warning(
f"We did not manage to extract a prediction in the correct format. Gold: {golds}, Pred: {predictions}"
)

# We have to use a timeout because the sympy-to-str conversion can be very slow
try:
add_to_specifics_with_timeout(formatted_doc, extracted_predictions, extracted_golds)
except: # noqa: E722
logger.warning("Timeout when adding extracted predictions and golds to specific")

return aggregation_function(
[
(1.0 if any(compare_gold_target(gold, pred, precision) for gold in extracted_golds) else 0.0)
for pred in extracted_predictions
]
)

return SampleLevelMetric(
metric_name="extractive_match",
sample_level_fn=sample_level_fn,
category=MetricCategory.GENERATIVE,
use_case=MetricUseCase.ACCURACY,
corpus_level_fn=np.mean,
higher_is_better=True,
)
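
A minimal usage sketch for the new factory (the variable name and the choice to combine expression and LaTeX extraction are illustrative, not prescribed by the diff):

```python
from lighteval.metrics.dynamic_metrics import (
    ExprExtractionConfig,
    LatexExtractionConfig,
    multilingual_extractive_match_metric,
)
from lighteval.utils.language import Language

# Extract either a plain math expression or a LaTeX answer (e.g. \boxed{...})
# from golds and predictions; a prediction scores 1.0 if it matches any gold,
# and `max` aggregation keeps the best-scoring prediction.
math_metric = multilingual_extractive_match_metric(
    language=Language.ENGLISH,
    gold_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
    pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
    aggregation_function=max,
    fallback_mode="first_match",
    precision=6,
)
```

The resulting `SampleLevelMetric` can then be attached to a task configuration like any other generative metric.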
28 changes: 19 additions & 9 deletions src/lighteval/metrics/llm_as_judge.py
@@ -193,22 +193,32 @@ def __call_litellm(self, prompts):
import litellm

def __call_api(prompt):
error_message = "ERROR: Failed to get response from the API."
for _ in range(self.API_MAX_RETRY):
try:
response = litellm.completion(
model=self.model,
messages=prompt,
response_format={"type": "text"},
max_tokens=512,
n=1,
caching=True,
)
kwargs = {
"model": self.model,
"messages": prompt,
"response_format": {"type": "text"},
"max_tokens": 512,
"n": 1,
"caching": True,
}
response = litellm.completion(**kwargs)
text = response.choices[0].message.content
if not text or text == error_message:
kwargs["caching"] = False
response = litellm.completion(**kwargs)
text = response.choices[0].message.content
if not text or text == error_message:
# Just return an error response if the second attempt fails too
logger.error(f"Failed to get response from the API for prompt: {prompt}")
return error_message
return text
except Exception as e:
logger.warning(f"{type(e), e}")
time.sleep(self.API_RETRY_SLEEP)
raise Exception("Failed to get response from the API")
return error_message

results = []
with ThreadPoolExecutor(100) as executor:
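
The retry logic above boils down to: try up to API_MAX_RETRY times; if a (possibly cached) response is empty or equals the error sentinel, retry once with caching disabled; if that also fails, return the sentinel. A stripped-down sketch of that control flow with a stand-in `call_api` callable (hypothetical, not lighteval code):

```python
import time


def robust_completion(call_api, prompt, max_retry=3, retry_sleep=10.0):
    """Retry transient failures and bypass the response cache once on a bad answer."""
    error_message = "ERROR: Failed to get response from the API."
    for _ in range(max_retry):
        try:
            text = call_api(prompt, caching=True)
            if not text or text == error_message:
                # A cached bad answer would otherwise be returned forever: bypass the cache once.
                text = call_api(prompt, caching=False)
            if not text or text == error_message:
                return error_message
            return text
        except Exception:
            time.sleep(retry_sleep)
    return error_message
```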
26 changes: 16 additions & 10 deletions src/lighteval/metrics/metrics_corpus.py
@@ -26,6 +26,7 @@
"""
import logging
import math
from typing import Literal

import numpy as np
import sacrebleu
@@ -89,33 +90,38 @@ def compute(self, items: list[LogprobCorpusMetricInput]):


class CorpusLevelTranslationMetric:
def __init__(self, metric_type: str):
def __init__(self, metric_type: str, lang: Literal["zh", "ja", "ko", ""] = ""):
"""Stores the relevant parameters for a corpus level translation metric.
Args:
metric_type (str): Can be any of bleu, chrf, or ter depending on the metric to use.
"""
if metric_type == "bleu":
self.metric = sacrebleu.corpus_bleu
elif metric_type == "chrf":
self.metric = sacrebleu.corpus_chrf
elif metric_type == "ter":
self.metric = sacrebleu.corpus_ter
self.metric_type = metric_type
self.lang = lang

def get_metric(self):
if self.metric_type == "bleu":
return sacrebleu.BLEU(trg_lang=self.lang)
elif self.metric_type == "chrf":
return sacrebleu.CHRF()
elif self.metric_type == "ter":
return sacrebleu.TER(asian_support=True if self.lang != "" else False)
else:
raise ValueError(f"Unknown corpus level translation metric type : {metric_type}")
raise ValueError(f"Unknown corpus level translation metric type : {self.metric_type}")

def compute(self, items: list[GenerativeCorpusMetricInput]) -> float:
"""Computes the metric score over all the corpus generated items, by using the sacrebleu implementation."""
metric = self.get_metric()
golds = [i.golds for i in items]
preds = []
for i in items:
pred = as_list(i.preds)
if len(pred) > 1:
logger.info(
f"Multiple predictions present, keeping only the first prediction (when computing sacrebleu.{self.metric.__name__})."
f"Multiple predictions present, keeping only the first prediction (when computing sacrebleu.{metric.__name__})."
)
preds.append(pred[0])
return float(self.metric(hypotheses=preds, references=golds).score)
return float(metric.corpus_score(hypotheses=preds, references=golds).score)


class CorpusLevelPerplexityMetric:
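
A minimal sketch of instantiating the updated translation metric for CJK targets (only the constructor and `get_metric` come from the diff above; the surrounding task wiring is assumed):

```python
from lighteval.metrics.metrics_corpus import CorpusLevelTranslationMetric

# BLEU with a Chinese target: get_metric() builds sacrebleu.BLEU(trg_lang="zh"),
# which selects the Chinese tokenizer.
bleu_zh = CorpusLevelTranslationMetric(metric_type="bleu", lang="zh")

# TER for Japanese: any non-empty language code enables asian_support.
ter_ja = CorpusLevelTranslationMetric(metric_type="ter", lang="ja")

# Default behaviour is unchanged for other languages.
chrf = CorpusLevelTranslationMetric(metric_type="chrf")
```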