Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix math extraction #503

Merged
merged 13 commits into from
Jan 18, 2025
38 changes: 20 additions & 18 deletions src/lighteval/metrics/utils/extractive_match_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,42 +103,42 @@ def lazy_expr_regex(expr_config: ExprExtractionConfig, language: Language) -> li
operators_re = "".join(operators)
all_expr_chars = r"[\d\.\s" + operators_re + r"]"
# Expression should have at minimum at least one operator and must start with a digit
expr_re = rf"-?\(?-?\d{all_expr_chars}*[{operators_re}]{all_expr_chars}+\)?"
expr_re = rf"(?P<expr>-?\(?-?\d{all_expr_chars}*[{operators_re}]{all_expr_chars}+\)?)"

# Punctuation regexes
full_stop_re = rf"[{re.escape(translation_literal.full_stop)}\.]"
comma_re = rf"[{re.escape(translation_literal.comma)}\,]"
colon_re = rf"[{re.escape(translation_literal.colon)}\:]"
space_re = rf"(?:\s|{re.escape(translation_literal.sentence_space)})"

currency_units = re.escape("$€£¥₹₽₪₩₫฿₡₢₣₤₥₦₧₨₩₪₫₭₮₯₰₱₲₳₴₵₶₷₸₹₺₻₼₽₾₿")
expr_prefix_re = rf"(?:^|{space_re}|\=)(?:\*\*)?"
expr_suffix_re = rf"(?:\*\*)?(?:{full_stop_re}|{comma_re}|{colon_re}|{space_re}|\)|\$|$)"

expr = f"(?P<expr>{expr_re}|{number_re})"
full_expr = rf"(?:{expr_prefix_re}{expr}{expr_suffix_re})"
# Expressions must be prefixed and suffixed while, digits don't need suffix and can have currency units preceeded, this is to ensure
# That we can extract stuff like $100 or 100m2, while we don't extract XDY2K as 2
expr_with_anchors = rf"(?:{expr_prefix_re}{expr_re}{expr_suffix_re})"
number_with_anchors = rf"(?:{expr_prefix_re}[{currency_units}]?{number_re})"
expr_or_number = rf"(?:{expr_with_anchors}|{number_with_anchors})"
regexes: list[tuple[str, int]] = []

# Ideally we would have translation of such concept in each language
if language == Language.ENGLISH:
final_answer_prefixed_re = rf"(?i:final answer is)\:?\s*{full_expr}\.?\s?I hope"
final_answer_prefixed_just_is = rf"(?i:final answer.{{0,100}}?)\s+is\:?{full_expr}"
final_answer_prefixed_re = rf"(?i:final answer is)\:?\s*{expr_or_number}\.?\s?I hope"
final_answer_prefixed_just_is = rf"(?i:final answer.{{0,100}}?)\s+is\:?{expr_or_number}"
regexes.append((final_answer_prefixed_re, 0))
regexes.append((final_answer_prefixed_just_is, 50))

answer_prefix_re = rf"(?i:{translation_literal.answer})"

# Match after the last equals with answer word - require the number pattern,
equals_re_colon = rf"{answer_prefix_re}{colon_re}(?:.{{0,100}}=\s*|.{{0,50}}?){full_expr}(?!\s*=)"
equals_re = rf"{answer_prefix_re}(?:.{{0,100}}=\s*|.{{0,50}}?){full_expr}(?!\s*=)"
equals_re_colon = rf"{answer_prefix_re}{colon_re}(?:.{{0,100}}=\s*|.{{0,50}}?){expr_or_number}(?!\s*=)"
equals_re = rf"{answer_prefix_re}(?:.{{0,100}}=\s*|.{{0,50}}?){expr_or_number}(?!\s*=)"
regexes.extend([(equals_re_colon, 100), (equals_re, 200)])

if expr_config.try_extract_without_anchor:
# If everything fails, try to match plain expr/number
regexes.append((f"({expr_prefix_re})(?P<expr>{expr_re})({expr_suffix_re})", 300))
regexes.append((f"({expr_prefix_re})(?P<expr>{number_re})({expr_suffix_re})", 300))

# Worst case just ignore any prefix/suffix, e.g 1$ wouldn't be extracted otherwise
regexes.append((f"((?P<expr>{number_re}))", 350))
regexes.append((expr_with_anchors, 300))
regexes.append((number_with_anchors, 300))

return [(re.compile(pattern), priority) for pattern, priority in regexes]

Expand Down Expand Up @@ -299,7 +299,7 @@ def extract_expr(match: re.Match) -> tuple[str | sympy.Expr | None, str]:
# First combine the number
groups = match.groupdict()
# Expr group will always exist because every regex has it
expr = groups["expr"]
expr = groups.get("expr", "")
integer = next((val for name, val in groups.items() if name.startswith("integer") and val), "")
decimal = next((val for name, val in groups.items() if name.startswith("decimal") and val), "")

Expand All @@ -321,10 +321,12 @@ def extract_expr(match: re.Match) -> tuple[str | sympy.Expr | None, str]:

# Otherwise just return the expression
# Remove new lines and spaces
try:
return parse_expr_with_timeout(expr.replace("\n", " ").replace("^", "**")), expr
except: # noqa: E722
return None, expr
if expr:
try:
return parse_expr_with_timeout(expr.replace("\n", " ").replace("^", "**")), expr
except: # noqa: E722
pass
return None, expr


def convert_to_pct(number: Number):
Expand Down
9 changes: 9 additions & 0 deletions tests/metrics/test_extractive_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,18 @@ def test_multilingual_extraction_math_latex_numbers(gold, pred, language, expect
("0.4", ".4", 1),
# Test decimals
("1000.99", "1,000.99", 1),
("1000.99", "1,000.99", 1),
# Test with units like $
("1000.99", "$1,000.99", 1),
("1000.99", "1,000.99$", 1),
# Test with currency units
("1000.99", "the number is not 10 which is 1,000.99€", 1),
("1000.99", "the number is not 10 which is 1,000.99€", 1),
# Test m2
("1000.99", "so the number is 10 which is 1,000.99m²", 1),
("1000.99", "not it's not 10 it's 1,000.99m²", 1),
# Test correct extraction of not correct answer
("2", "AZYUK2A", 0),
],
)
def test_number_extraction(gold, pred, expected):
Expand Down
2 changes: 1 addition & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from pathlib import Path
from types import ModuleType
from typing import Optional, Union
from unittest.mock import patch

from anyio import Path
from transformers import AutoTokenizer

from lighteval.logging.evaluation_tracker import EvaluationTracker
Expand Down
Loading