diff --git a/python/mlcroissant/mlcroissant/_src/core/regex.py b/python/mlcroissant/mlcroissant/_src/core/regex.py new file mode 100644 index 000000000..8633e42a9 --- /dev/null +++ b/python/mlcroissant/mlcroissant/_src/core/regex.py @@ -0,0 +1,88 @@ +from collections.abc import Iterable +import itertools +import re +from typing import Any + + +def regex_to_glob(regex: str | list[str]) -> list[str]: + """Converts a regular expression to several glob patterns. + + The function applies several transformations to achieve this: + + - Expand all non-capturing groups to possibly several expression. For example, + for Hugging Face croissants, "default/(?:partial-)?train/.+parquet$" would become + ["default/train/.+parquet$", "default/partial-train/.+parquet$"] to match both with + and without the `partial-` prefix. + + - Convert to a glob pattern using some heuristics. + """ + if isinstance(regex, str): + results = [regex] + for fn in [_expand_non_capturing_groups, _regex_to_glob_for_str]: + results = list(itertools.chain.from_iterable(fn(result) for result in results)) + return results + + +def _expand_non_capturing_groups(regex: str) -> Iterable[str]: + if "(?:" not in regex: + # There is no non-capturing group: + return [regex] + + # Find all capturing groups: + pattern = r"\(\?:.*?\)\?|[^()]+" + strings = re.findall(pattern, regex) + if not strings: + raise ValueError("the string should not be empty") + subregex = "".join(strings[1:]) + # Recursively construct the results for the sub-regex: + results = _expand_non_capturing_groups(subregex) + string = strings[0] + if string.startswith("(?:") and string.endswith(")?"): + # Append the inside of the non-capturing group: + string = string[len("(?:") : -len(")?")] + return itertools.chain( + [f"{string}{result}" for result in results], + results, + ) + else: + # Append the string itself to each result: + return [f"{string}{result}" for result in results] + + +def _regex_to_glob_for_str(regex: str) -> Iterable[str]: + """Converts a regular expression to a glob pattern by unescaping regex syntax. + + Warning: this is based on manual heuristics to convert a regular expression to a + glob expression. + """ + # Remove starting ^ + regex = re.sub(r"^\^", "", regex) + # Remove trailing $ + regex = re.sub(r"\$$", "", regex) + # Interpret \. as . + regex = re.sub(r"\\\.", ".", regex) + # Interpret .* as * + regex = re.sub(r"\.\*", "*", regex) + # Interpret .+ as * + regex = re.sub(r"\.\+", "*", regex) + return [regex] + + +def capture_one_capturing_group(str_regex: str, value: Any) -> str: + """Captures the one and only capturing group, but ignoring non-capturing gorups.""" + # Non-capturing groups have the form (?:a_non_capturing_group) + capturing_groups = re.compile(r"\((?!\?\:).*?\)") + groups = capturing_groups.findall(str_regex) + if len(groups) == 1: + matches = re.match(groups[0], value) + if not matches: + raise ValueError( + "The replace value doesn't respect the expected capturing group." + f" Expected: {groups[0]}. Got: {value}" + ) + else: + raise ValueError( + "A transform regex should have exactly 1 capturing group in" + f" the transform regex. Got: '{str_regex}'" + ) + return capturing_groups.sub(re.escape(value), str_regex) diff --git a/python/mlcroissant/mlcroissant/_src/core/regex_test.py b/python/mlcroissant/mlcroissant/_src/core/regex_test.py new file mode 100644 index 000000000..3a153f01b --- /dev/null +++ b/python/mlcroissant/mlcroissant/_src/core/regex_test.py @@ -0,0 +1,54 @@ +"""regex_test module.""" + +import pytest + +from mlcroissant._src.core import regex as regex_lib + + +@pytest.mark.parametrize( + ["regex", "output"], + [ + [ + "(?:baz)?xxx(?:foo)?yyy(?:bar)?", + [ + "bazxxxfooyyybar", + "bazxxxfooyyy", + "bazxxxyyybar", + "bazxxxyyy", + "xxxfooyyybar", + "xxxfooyyy", + "xxxyyybar", + "xxxyyy", + ], + ], + [ + "^.+/train/.*\.parquet$", # From a valid regex... + [ + "*/train/*.parquet", # ...to a valid glob pattern. + ], + ], + ], +) +def test_regex_to_glob(regex: str, output: list[str]): + assert regex_lib.regex_to_glob(regex) == output + + +def test_capture_one_capturing_group(): + # The value does not match: + with pytest.raises(ValueError): + regex_lib.capture_one_capturing_group( + "default/(?:partial-)?(train|test)/.+parquet$", "NOT MATCHING" + ) + + # Too many capturing groups: + with pytest.raises(ValueError): + regex_lib.capture_one_capturing_group( + "(default)/(?:partial-)?(train|test)/.+parquet$", "NOT MATCHING" + ) + + assert ( + regex_lib.capture_one_capturing_group( + "default/(?:partial-)?(train|test)/.+parquet$", "train" + ) + == "default/(?:partial-)?train/.+parquet$" + ) diff --git a/python/mlcroissant/mlcroissant/_src/datasets.py b/python/mlcroissant/mlcroissant/_src/datasets.py index 488efe86b..fb8c942b8 100644 --- a/python/mlcroissant/mlcroissant/_src/datasets.py +++ b/python/mlcroissant/mlcroissant/_src/datasets.py @@ -4,7 +4,6 @@ from collections.abc import Mapping import dataclasses -import re import typing from typing import Any @@ -12,6 +11,7 @@ from etils import epath import networkx as nx +from mlcroissant._src.core import regex as regex_lib from mlcroissant._src.core.context import Context from mlcroissant._src.core.graphs import utils as graphs_utils from mlcroissant._src.core.issues import ValidationError @@ -236,25 +236,6 @@ def _find_data_field_to_filter( ) -def _regex_to_glob(regex: str) -> str: - """Converts a regular expression to a blob pattern by unescaping regex syntax. - - Warning: this is based on manual heuristics to convert a regular expression to a - glob expression. - """ - # Remove starting ^ - regex = re.sub(r"^\^", "", regex) - # Remove trailing $ - regex = re.sub(r"\$$", "", regex) - # Interpret \. as . - regex = re.sub(r"\\\.", ".", regex) - # Interpret .* as * - regex = re.sub(r"\.\*", "*", regex) - # Interpret .+ as * - regex = re.sub(r"\.\+", "*", regex) - return regex - - def _regex_from_value(field: Field, value: Any): """Creates a regular expression by injecting the value in the transformation.""" transforms = field.source.transforms @@ -266,17 +247,7 @@ def _regex_from_value(field: Field, value: Any): raise NotImplementedError(error) transform = transforms[0] if str_regex := transform.regex: - capturing_groups = re.compile(r"\(.*\)") - groups = capturing_groups.findall(str_regex) - if len(groups) == 1: - # Check that the value respects the expected capturing group: - re.match(groups[0], value) - else: - raise ValueError( - "A transform regex should have exactly 1 capturing group in" - f" the transform regex. Got: '{str_regex}'" - ) - return capturing_groups.sub(re.escape(value), str_regex) + return regex_lib.capture_one_capturing_group(str_regex, value) raise NotImplementedError(error) @@ -313,12 +284,13 @@ def _propagate_includes(field: Field, operations: nx.Graph[Operation], new_regex filename_pattern = pattern.split("/") if len(filename_pattern) <= 1: raise NotImplementedError() - filename = _regex_to_glob(new_regex) - new_pattern = filename_pattern[:-1] + [filename] - new_includes.append("/".join(new_pattern)) + filenames = regex_lib.regex_to_glob(new_regex) + for filename in filenames: + new_pattern = filename_pattern[:-1] + [filename] + new_includes.append("/".join(new_pattern)) node.includes = new_includes elif source_type == FileProperty.fullpath: - node.includes = [_regex_to_glob(new_regex) for _ in includes] + node.includes = regex_lib.regex_to_glob(new_regex) else: raise NotImplementedError(error)