-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Handle non-capturing groups in regex transforms.
- Loading branch information
Showing
3 changed files
with
149 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from collections.abc import Iterable | ||
import itertools | ||
import re | ||
from typing import Any | ||
|
||
|
||
def regex_to_glob(regex: str | list[str]) -> list[str]: | ||
"""Converts a regular expression to several glob patterns. | ||
The function applies several transformations to achieve this: | ||
- Expand all non-capturing groups to possibly several expression. For example, | ||
for Hugging Face croissants, "default/(?:partial-)?train/.+parquet$" would become | ||
["default/train/.+parquet$", "default/partial-train/.+parquet$"] to match both with | ||
and without the `partial-` prefix. | ||
- Convert to a glob pattern using some heuristics. | ||
""" | ||
if isinstance(regex, str): | ||
results = [regex] | ||
for fn in [_expand_non_capturing_groups, _regex_to_glob_for_str]: | ||
results = list(itertools.chain.from_iterable(fn(result) for result in results)) | ||
return results | ||
|
||
|
||
def _expand_non_capturing_groups(regex: str) -> Iterable[str]: | ||
if "(?:" not in regex: | ||
# There is no non-capturing group: | ||
return [regex] | ||
|
||
# Find all capturing groups: | ||
pattern = r"\(\?:.*?\)\?|[^()]+" | ||
strings = re.findall(pattern, regex) | ||
if not strings: | ||
raise ValueError("the string should not be empty") | ||
subregex = "".join(strings[1:]) | ||
# Recursively construct the results for the sub-regex: | ||
results = _expand_non_capturing_groups(subregex) | ||
string = strings[0] | ||
if string.startswith("(?:") and string.endswith(")?"): | ||
# Append the inside of the non-capturing group: | ||
string = string[len("(?:") : -len(")?")] | ||
return itertools.chain( | ||
[f"{string}{result}" for result in results], | ||
results, | ||
) | ||
else: | ||
# Append the string itself to each result: | ||
return [f"{string}{result}" for result in results] | ||
|
||
|
||
def _regex_to_glob_for_str(regex: str) -> Iterable[str]: | ||
"""Converts a regular expression to a glob pattern by unescaping regex syntax. | ||
Warning: this is based on manual heuristics to convert a regular expression to a | ||
glob expression. | ||
""" | ||
# Remove starting ^ | ||
regex = re.sub(r"^\^", "", regex) | ||
# Remove trailing $ | ||
regex = re.sub(r"\$$", "", regex) | ||
# Interpret \. as . | ||
regex = re.sub(r"\\\.", ".", regex) | ||
# Interpret .* as * | ||
regex = re.sub(r"\.\*", "*", regex) | ||
# Interpret .+ as * | ||
regex = re.sub(r"\.\+", "*", regex) | ||
return [regex] | ||
|
||
|
||
def capture_one_capturing_group(str_regex: str, value: Any) -> str: | ||
"""Captures the one and only capturing group, but ignoring non-capturing gorups.""" | ||
# Non-capturing groups have the form (?:a_non_capturing_group) | ||
capturing_groups = re.compile(r"\((?!\?\:).*?\)") | ||
groups = capturing_groups.findall(str_regex) | ||
if len(groups) == 1: | ||
matches = re.match(groups[0], value) | ||
if not matches: | ||
raise ValueError( | ||
"The replace value doesn't respect the expected capturing group." | ||
f" Expected: {groups[0]}. Got: {value}" | ||
) | ||
else: | ||
raise ValueError( | ||
"A transform regex should have exactly 1 capturing group in" | ||
f" the transform regex. Got: '{str_regex}'" | ||
) | ||
return capturing_groups.sub(re.escape(value), str_regex) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
"""regex_test module.""" | ||
|
||
import pytest | ||
|
||
from mlcroissant._src.core import regex as regex_lib | ||
|
||
|
||
@pytest.mark.parametrize( | ||
["regex", "output"], | ||
[ | ||
[ | ||
"(?:baz)?xxx(?:foo)?yyy(?:bar)?", | ||
[ | ||
"bazxxxfooyyybar", | ||
"bazxxxfooyyy", | ||
"bazxxxyyybar", | ||
"bazxxxyyy", | ||
"xxxfooyyybar", | ||
"xxxfooyyy", | ||
"xxxyyybar", | ||
"xxxyyy", | ||
], | ||
], | ||
[ | ||
"^.+/train/.*\.parquet$", # From a valid regex... | ||
[ | ||
"*/train/*.parquet", # ...to a valid glob pattern. | ||
], | ||
], | ||
], | ||
) | ||
def test_regex_to_glob(regex: str, output: list[str]): | ||
assert regex_lib.regex_to_glob(regex) == output | ||
|
||
|
||
def test_capture_one_capturing_group(): | ||
# The value does not match: | ||
with pytest.raises(ValueError): | ||
regex_lib.capture_one_capturing_group( | ||
"default/(?:partial-)?(train|test)/.+parquet$", "NOT MATCHING" | ||
) | ||
|
||
# Too many capturing groups: | ||
with pytest.raises(ValueError): | ||
regex_lib.capture_one_capturing_group( | ||
"(default)/(?:partial-)?(train|test)/.+parquet$", "NOT MATCHING" | ||
) | ||
|
||
assert ( | ||
regex_lib.capture_one_capturing_group( | ||
"default/(?:partial-)?(train|test)/.+parquet$", "train" | ||
) | ||
== "default/(?:partial-)?train/.+parquet$" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters