Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle non-capturing groups in regex transforms (partial-train/*.parquet). #774

Merged
merged 1 commit into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions python/mlcroissant/mlcroissant/_src/core/regex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from collections.abc import Iterable
import itertools
import re
from typing import Any


def regex_to_glob(regex: str | list[str]) -> list[str]:
"""Converts a regular expression to several glob patterns.

The function applies several transformations to achieve this:

- Expand all non-capturing groups to possibly several expression. For example,
for Hugging Face croissants, "default/(?:partial-)?train/.+parquet$" would become
["default/train/.+parquet$", "default/partial-train/.+parquet$"] to match both with
and without the `partial-` prefix.

- Convert to a glob pattern using some heuristics.
"""
if isinstance(regex, str):
results = [regex]
for fn in [_expand_non_capturing_groups, _regex_to_glob_for_str]:
results = list(itertools.chain.from_iterable(fn(result) for result in results))
return results


def _expand_non_capturing_groups(regex: str) -> Iterable[str]:
if "(?:" not in regex:
# There is no non-capturing group:
return [regex]

# Find all capturing groups:
pattern = r"\(\?:.*?\)\?|[^()]+"
strings = re.findall(pattern, regex)
if not strings:
raise ValueError("the string should not be empty")
subregex = "".join(strings[1:])
# Recursively construct the results for the sub-regex:
results = _expand_non_capturing_groups(subregex)
string = strings[0]
if string.startswith("(?:") and string.endswith(")?"):
# Append the inside of the non-capturing group:
string = string[len("(?:") : -len(")?")]
return itertools.chain(
[f"{string}{result}" for result in results],
results,
)
else:
# Append the string itself to each result:
return [f"{string}{result}" for result in results]


def _regex_to_glob_for_str(regex: str) -> Iterable[str]:
"""Converts a regular expression to a glob pattern by unescaping regex syntax.

Warning: this is based on manual heuristics to convert a regular expression to a
glob expression.
"""
# Remove starting ^
regex = re.sub(r"^\^", "", regex)
# Remove trailing $
regex = re.sub(r"\$$", "", regex)
# Interpret \. as .
regex = re.sub(r"\\\.", ".", regex)
# Interpret .* as *
regex = re.sub(r"\.\*", "*", regex)
# Interpret .+ as *
regex = re.sub(r"\.\+", "*", regex)
return [regex]


def capture_one_capturing_group(str_regex: str, value: Any) -> str:
"""Captures the one and only capturing group, but ignoring non-capturing gorups."""
# Non-capturing groups have the form (?:a_non_capturing_group)
capturing_groups = re.compile(r"\((?!\?\:).*?\)")
groups = capturing_groups.findall(str_regex)
if len(groups) == 1:
matches = re.match(groups[0], value)
if not matches:
raise ValueError(
"The replace value doesn't respect the expected capturing group."
f" Expected: {groups[0]}. Got: {value}"
)
else:
raise ValueError(
"A transform regex should have exactly 1 capturing group in"
f" the transform regex. Got: '{str_regex}'"
)
return capturing_groups.sub(re.escape(value), str_regex)
54 changes: 54 additions & 0 deletions python/mlcroissant/mlcroissant/_src/core/regex_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""regex_test module."""

import pytest

from mlcroissant._src.core import regex as regex_lib


@pytest.mark.parametrize(
["regex", "output"],
[
[
"(?:baz)?xxx(?:foo)?yyy(?:bar)?",
[
"bazxxxfooyyybar",
"bazxxxfooyyy",
"bazxxxyyybar",
"bazxxxyyy",
"xxxfooyyybar",
"xxxfooyyy",
"xxxyyybar",
"xxxyyy",
],
],
[
"^.+/train/.*\.parquet$", # From a valid regex...
[
"*/train/*.parquet", # ...to a valid glob pattern.
],
],
],
)
def test_regex_to_glob(regex: str, output: list[str]):
assert regex_lib.regex_to_glob(regex) == output


def test_capture_one_capturing_group():
# The value does not match:
with pytest.raises(ValueError):
regex_lib.capture_one_capturing_group(
"default/(?:partial-)?(train|test)/.+parquet$", "NOT MATCHING"
)

# Too many capturing groups:
with pytest.raises(ValueError):
regex_lib.capture_one_capturing_group(
"(default)/(?:partial-)?(train|test)/.+parquet$", "NOT MATCHING"
)

assert (
regex_lib.capture_one_capturing_group(
"default/(?:partial-)?(train|test)/.+parquet$", "train"
)
== "default/(?:partial-)?train/.+parquet$"
)
42 changes: 7 additions & 35 deletions python/mlcroissant/mlcroissant/_src/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@

from collections.abc import Mapping
import dataclasses
import re
import typing
from typing import Any

from absl import logging
from etils import epath
import networkx as nx

from mlcroissant._src.core import regex as regex_lib
from mlcroissant._src.core.context import Context
from mlcroissant._src.core.graphs import utils as graphs_utils
from mlcroissant._src.core.issues import ValidationError
Expand Down Expand Up @@ -236,25 +236,6 @@ def _find_data_field_to_filter(
)


def _regex_to_glob(regex: str) -> str:
"""Converts a regular expression to a blob pattern by unescaping regex syntax.

Warning: this is based on manual heuristics to convert a regular expression to a
glob expression.
"""
# Remove starting ^
regex = re.sub(r"^\^", "", regex)
# Remove trailing $
regex = re.sub(r"\$$", "", regex)
# Interpret \. as .
regex = re.sub(r"\\\.", ".", regex)
# Interpret .* as *
regex = re.sub(r"\.\*", "*", regex)
# Interpret .+ as *
regex = re.sub(r"\.\+", "*", regex)
return regex


def _regex_from_value(field: Field, value: Any):
"""Creates a regular expression by injecting the value in the transformation."""
transforms = field.source.transforms
Expand All @@ -266,17 +247,7 @@ def _regex_from_value(field: Field, value: Any):
raise NotImplementedError(error)
transform = transforms[0]
if str_regex := transform.regex:
capturing_groups = re.compile(r"\(.*\)")
groups = capturing_groups.findall(str_regex)
if len(groups) == 1:
# Check that the value respects the expected capturing group:
re.match(groups[0], value)
else:
raise ValueError(
"A transform regex should have exactly 1 capturing group in"
f" the transform regex. Got: '{str_regex}'"
)
return capturing_groups.sub(re.escape(value), str_regex)
return regex_lib.capture_one_capturing_group(str_regex, value)
raise NotImplementedError(error)


Expand Down Expand Up @@ -313,12 +284,13 @@ def _propagate_includes(field: Field, operations: nx.Graph[Operation], new_regex
filename_pattern = pattern.split("/")
if len(filename_pattern) <= 1:
raise NotImplementedError()
filename = _regex_to_glob(new_regex)
new_pattern = filename_pattern[:-1] + [filename]
new_includes.append("/".join(new_pattern))
filenames = regex_lib.regex_to_glob(new_regex)
for filename in filenames:
new_pattern = filename_pattern[:-1] + [filename]
new_includes.append("/".join(new_pattern))
node.includes = new_includes
elif source_type == FileProperty.fullpath:
node.includes = [_regex_to_glob(new_regex) for _ in includes]
node.includes = regex_lib.regex_to_glob(new_regex)
else:
raise NotImplementedError(error)

Expand Down