Skip to content

Commit

Permalink
Handle non-capturing groups in regex transforms.
Browse files Browse the repository at this point in the history
  • Loading branch information
marcenacp committed Nov 29, 2024
1 parent 4e30d0d commit 2a5ee7e
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 35 deletions.
88 changes: 88 additions & 0 deletions python/mlcroissant/mlcroissant/_src/core/regex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from collections.abc import Iterable
import itertools
import re
from typing import Any


def regex_to_glob(regex: str | list[str]) -> list[str]:
"""Converts a regular expression to several glob patterns.
The function applies several transformations to achieve this:
- Expand all non-capturing groups to possibly several expression. For example,
for Hugging Face croissants, "default/(?:partial-)?train/.+parquet$" would become
["default/train/.+parquet$", "default/partial-train/.+parquet$"] to match both with
and without the `partial-` prefix.
- Convert to a glob pattern using some heuristics.
"""
if isinstance(regex, str):
results = [regex]
for fn in [_expand_non_capturing_groups, _regex_to_glob_for_str]:
results = list(itertools.chain.from_iterable(fn(result) for result in results))
return results


def _expand_non_capturing_groups(regex: str) -> Iterable[str]:
if "(?:" not in regex:
# There is no non-capturing group:
return [regex]

# Find all capturing groups:
pattern = r"\(\?:.*?\)\?|[^()]+"
strings = re.findall(pattern, regex)
if not strings:
raise ValueError("the string should not be empty")
subregex = "".join(strings[1:])
# Recursively construct the results for the sub-regex:
results = _expand_non_capturing_groups(subregex)
string = strings[0]
if string.startswith("(?:") and string.endswith(")?"):
# Append the inside of the non-capturing group:
string = string[len("(?:") : -len(")?")]
return itertools.chain(
[f"{string}{result}" for result in results],
results,
)
else:
# Append the string itself to each result:
return [f"{string}{result}" for result in results]


def _regex_to_glob_for_str(regex: str) -> Iterable[str]:
"""Converts a regular expression to a glob pattern by unescaping regex syntax.
Warning: this is based on manual heuristics to convert a regular expression to a
glob expression.
"""
# Remove starting ^
regex = re.sub(r"^\^", "", regex)
# Remove trailing $
regex = re.sub(r"\$$", "", regex)
# Interpret \. as .
regex = re.sub(r"\\\.", ".", regex)
# Interpret .* as *
regex = re.sub(r"\.\*", "*", regex)
# Interpret .+ as *
regex = re.sub(r"\.\+", "*", regex)
return [regex]


def capture_one_capturing_group(str_regex: str, value: Any) -> str:
"""Captures the one and only capturing group, but ignoring non-capturing gorups."""
# Non-capturing groups have the form (?:a_non_capturing_group)
capturing_groups = re.compile(r"\((?!\?\:).*?\)")
groups = capturing_groups.findall(str_regex)
if len(groups) == 1:
matches = re.match(groups[0], value)
if not matches:
raise ValueError(
"The replace value doesn't respect the expected capturing group."
f" Expected: {groups[0]}. Got: {value}"
)
else:
raise ValueError(
"A transform regex should have exactly 1 capturing group in"
f" the transform regex. Got: '{str_regex}'"
)
return capturing_groups.sub(re.escape(value), str_regex)
54 changes: 54 additions & 0 deletions python/mlcroissant/mlcroissant/_src/core/regex_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""regex_test module."""

import pytest

from mlcroissant._src.core import regex as regex_lib


@pytest.mark.parametrize(
["regex", "output"],
[
[
"(?:baz)?xxx(?:foo)?yyy(?:bar)?",
[
"bazxxxfooyyybar",
"bazxxxfooyyy",
"bazxxxyyybar",
"bazxxxyyy",
"xxxfooyyybar",
"xxxfooyyy",
"xxxyyybar",
"xxxyyy",
],
],
[
"^.+/train/.*\.parquet$", # From a valid regex...
[
"*/train/*.parquet", # ...to a valid glob pattern.
],
],
],
)
def test_regex_to_glob(regex: str, output: list[str]):
assert regex_lib.regex_to_glob(regex) == output


def test_capture_one_capturing_group():
# The value does not match:
with pytest.raises(ValueError):
regex_lib.capture_one_capturing_group(
"default/(?:partial-)?(train|test)/.+parquet$", "NOT MATCHING"
)

# Too many capturing groups:
with pytest.raises(ValueError):
regex_lib.capture_one_capturing_group(
"(default)/(?:partial-)?(train|test)/.+parquet$", "NOT MATCHING"
)

assert (
regex_lib.capture_one_capturing_group(
"default/(?:partial-)?(train|test)/.+parquet$", "train"
)
== "default/(?:partial-)?train/.+parquet$"
)
42 changes: 7 additions & 35 deletions python/mlcroissant/mlcroissant/_src/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@

from collections.abc import Mapping
import dataclasses
import re
import typing
from typing import Any

from absl import logging
from etils import epath
import networkx as nx

from mlcroissant._src.core import regex as regex_lib
from mlcroissant._src.core.context import Context
from mlcroissant._src.core.graphs import utils as graphs_utils
from mlcroissant._src.core.issues import ValidationError
Expand Down Expand Up @@ -236,25 +236,6 @@ def _find_data_field_to_filter(
)


def _regex_to_glob(regex: str) -> str:
"""Converts a regular expression to a blob pattern by unescaping regex syntax.
Warning: this is based on manual heuristics to convert a regular expression to a
glob expression.
"""
# Remove starting ^
regex = re.sub(r"^\^", "", regex)
# Remove trailing $
regex = re.sub(r"\$$", "", regex)
# Interpret \. as .
regex = re.sub(r"\\\.", ".", regex)
# Interpret .* as *
regex = re.sub(r"\.\*", "*", regex)
# Interpret .+ as *
regex = re.sub(r"\.\+", "*", regex)
return regex


def _regex_from_value(field: Field, value: Any):
"""Creates a regular expression by injecting the value in the transformation."""
transforms = field.source.transforms
Expand All @@ -266,17 +247,7 @@ def _regex_from_value(field: Field, value: Any):
raise NotImplementedError(error)
transform = transforms[0]
if str_regex := transform.regex:
capturing_groups = re.compile(r"\(.*\)")
groups = capturing_groups.findall(str_regex)
if len(groups) == 1:
# Check that the value respects the expected capturing group:
re.match(groups[0], value)
else:
raise ValueError(
"A transform regex should have exactly 1 capturing group in"
f" the transform regex. Got: '{str_regex}'"
)
return capturing_groups.sub(re.escape(value), str_regex)
return regex_lib.capture_one_capturing_group(str_regex, value)
raise NotImplementedError(error)


Expand Down Expand Up @@ -313,12 +284,13 @@ def _propagate_includes(field: Field, operations: nx.Graph[Operation], new_regex
filename_pattern = pattern.split("/")
if len(filename_pattern) <= 1:
raise NotImplementedError()
filename = _regex_to_glob(new_regex)
new_pattern = filename_pattern[:-1] + [filename]
new_includes.append("/".join(new_pattern))
filenames = regex_lib.regex_to_glob(new_regex)
for filename in filenames:
new_pattern = filename_pattern[:-1] + [filename]
new_includes.append("/".join(new_pattern))
node.includes = new_includes
elif source_type == FileProperty.fullpath:
node.includes = [_regex_to_glob(new_regex) for _ in includes]
node.includes = regex_lib.regex_to_glob(new_regex)
else:
raise NotImplementedError(error)

Expand Down

0 comments on commit 2a5ee7e

Please sign in to comment.