Skip to content

Commit

Permalink
Merge branch 'main' into Document-Custom-Model-Files
Browse files Browse the repository at this point in the history
  • Loading branch information
ParagEkbote authored Jan 17, 2025
2 parents 5d225b3 + 59624c8 commit 262b1cd
Show file tree
Hide file tree
Showing 9 changed files with 2,053 additions and 4 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

---

**Documentation**: <a href="https://github.com/huggingface/lighteval/wiki" target="_blank">Lighteval's Wiki</a>
**Documentation**: <a href="https://huggingface.co/docs/lighteval/index" target="_blank">Lighteval's Wiki</a>

---

Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ tensorboardX = ["tensorboardX"]
vllm = ["vllm", "ray", "more_itertools"]
quality = ["ruff==v0.2.2","pre-commit"]
tests = ["pytest==7.4.0"]
dev = ["lighteval[accelerate,quality,tests,multilingual]"]
dev = ["lighteval[accelerate,quality,tests,multilingual,math]"]
docs = ["hf-doc-builder", "watchdog"]
extended_tasks = [
"langdetect", # ifeval
Expand All @@ -109,6 +109,7 @@ multilingual = [
"jieba", # for chinese tokenizer
"pyvi", # for vietnamese tokenizer
]
math = ["latex2sympy2_extended>=0.9.0"]

[project.urls]
Homepage = "https://github.com/huggingface/lighteval"
Expand Down
108 changes: 107 additions & 1 deletion src/lighteval/metrics/dynamic_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import Callable, Literal
import logging
from typing import Callable, Literal, Sequence

import numpy as np

Expand All @@ -37,8 +38,22 @@
LogProbTokenNorm,
get_multilingual_normalizer,
)
from lighteval.metrics.utils.extractive_match_utils import ( # noqa: F401
ExprExtractionConfig,
ExtractionTarget,
IndicesExtractionConfig,
LatexExtractionConfig,
extract_target_from_pred,
get_extraction_regexes,
)
from lighteval.metrics.utils.math_comparison import compare_gold_target
from lighteval.metrics.utils.metric_utils import MetricCategory, MetricUseCase, SampleLevelMetric
from lighteval.tasks.requests import Doc
from lighteval.utils.language import Language
from lighteval.utils.timeout import timeout


logger = logging.getLogger(__name__)


def loglikelihood_acc_metric(normalization: LogProbNormalization | None = None) -> SampleLevelMetric:
Expand Down Expand Up @@ -168,3 +183,94 @@ def multilingual_quasi_exact_match_metric(
corpus_level_fn=np.mean,
higher_is_better=True,
)


def multilingual_extractive_match_metric(
language: Language = Language.ENGLISH,
gold_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),),
pred_extraction_target: Sequence[ExtractionTarget] = (ExprExtractionConfig(),),
aggregation_function: Callable[[list[float]], float] = max,
fallback_mode: Literal["no_fallback", "first_match"] = "first_match",
precision: int = 6,
) -> SampleLevelMetric:
"""Creates a language-aware extractive match metric that extracts answers from the model's output.
Known issues:
- If the task is to simplify an expression, the metric might overestimate the accuracy. This is because if the model doesn't output any anchor for the extraction (e.g final answer is..),
it's possible that the the extracted prediction will be the expression to simplify. Because we do simplifications ourselves, it can thus happen that sympy will correctly simplify the expression,
thus it will match gold, despite model not doing anything. PRs to fix this are welcome.
- There is currently no StringExtractionConfig, so if the gold is \boxed{\text{Friday}} and model outputs Friday it will not match, because nothing will be extracted.
Args:
language: Language
The language of the samples.
gold_extraction_target: Sequence[ExtractionTarget]
Extraction targets to use for gold answers. Defaults to extracting simple math expressions.
pred_extraction_target: Sequence[ExtractionTarget]
Extraction targets to use for predictions. Defaults to extracting simple math expressions.
aggregation_function: Callable[[list[float]], float]
Function to aggregate scores when multiple golds/predictions are present. Defaults to max.
fallback_mode: Literal["no_fallback", "first_match"]
How to perform extraction. Defaults to "first_match".
- "no_fallback": Only use first successfully parsed matches
- "first_match": Use the first successfully parsed match + first match irregardless the parsing success
precision: int
Number of decimal places to use when comparing numerical values. Defaults to 6.
Returns:
A sample level metric that extracts and compares mathematical expressions.
"""

@timeout(2)
def add_to_specifics_with_timeout(
formatted_doc: Doc, extracted_predictions: list[list[str]], extracted_golds: list[list[str]]
) -> None:
if formatted_doc.specific is None:
formatted_doc.specific = {}

formatted_doc.specific["extracted_predictions"] = [
str(pred) for preds in extracted_predictions for pred in preds
]
formatted_doc.specific["extracted_golds"] = [str(gold) for golds in extracted_golds for gold in golds]

def sample_level_fn(golds: list[str], predictions: list[str], formatted_doc: Doc) -> float:
gold_extraction_regexes = get_extraction_regexes(formatted_doc, gold_extraction_target, language)
pred_extraction_regexes = get_extraction_regexes(formatted_doc, pred_extraction_target, language)

extracted_predictions = [
extract_target_from_pred(pred, pred_extraction_regexes, fallback_mode) for pred in predictions
]
extracted_golds = [extract_target_from_pred(gold, gold_extraction_regexes, fallback_mode) for gold in golds]

# Assert on empty gold and warn on empty pred
if any(len(g) == 0 for g in extracted_golds):
raise ValueError(f"No gold targets found for at least one gold. Gold: {golds}, Pred: {predictions}")

if all(len(p) == 0 for p in extracted_predictions):
logger.warning(
f"We did not manage to extract a prediction in the correct format. Gold: {golds}, Pred: {predictions}"
)

# We have to use timeout because the sypmy to str conversion can be very slow
try:
add_to_specifics_with_timeout(formatted_doc, extracted_predictions, extracted_golds)
except: # noqa: E722
logger.warning("Timeout when adding extracted predictions and golds to specific")

return aggregation_function(
[
(1.0 if any(compare_gold_target(gold, pred, precision) for gold in extracted_golds) else 0.0)
for pred in extracted_predictions
]
)

return SampleLevelMetric(
metric_name="extractive_match",
sample_level_fn=sample_level_fn,
category=MetricCategory.GENERATIVE,
use_case=MetricUseCase.ACCURACY,
corpus_level_fn=np.mean,
higher_is_better=True,
)
Loading

0 comments on commit 262b1cd

Please sign in to comment.