Drop all useless operations when we filter on a field - since we then know its value in advance.
marcenacp committed Nov 29, 2024
1 parent d1e81bd commit 7ca7552
Showing 2 changed files with 41 additions and 11 deletions.
18 changes: 17 additions & 1 deletion python/mlcroissant/mlcroissant/_src/datasets.py
@@ -218,7 +218,23 @@ def _filter_interesting_operations(self, filters: Filters | None) -> Operations:
         field, value = _find_data_field_to_filter(filters, interesting_operations)
         new_regex = _regex_from_value(field, value)
         _propagate_includes(field, interesting_operations, new_regex)
-        return interesting_operations  # pytype: disable=bad-return-type
+        # The value of `field` is now entirely known so we can remove any operation
+        # needed to compute it, i.e. all operations involved in a potential join:
+        join_uuid = field.references.uuid
+        graph = field.ctx.graph
+        if join_uuid:
+            join_node = next(node for node in graph if node.uuid == join_uuid)
+            unneeded_nodes = [
+                node
+                for node in graph
+                if graph.has_edge(node, join_node) or node == join_node
+            ]
+            interesting_operations = [
+                o for o in interesting_operations if o.node not in unneeded_nodes
+            ]
+        return operations.subgraph(
+            interesting_operations
+        )  # pytype: disable=bad-return-type


def _find_data_field_to_filter(
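For intuition, here is a minimal, self-contained sketch of the pruning predicate added above. It assumes a networkx-style directed graph (consistent with the `has_edge` call in the diff); the node names are toy stand-ins, not mlcroissant's actual node types:

import networkx as nx

# Toy structure graph: two sources feed a join, and the join feeds the
# field whose value the filter fixes in advance.
graph = nx.DiGraph()
graph.add_edges_from([
    ("source_a", "join"),
    ("source_b", "join"),
    ("join", "field"),
])

join_node = "join"
# Same predicate as in the diff: the join node and its direct
# predecessors are unneeded once the field's value is already known.
unneeded_nodes = [
    node for node in graph
    if graph.has_edge(node, join_node) or node == join_node
]
assert set(unneeded_nodes) == {"source_a", "source_b", "join"}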
34 changes: 24 additions & 10 deletions python/mlcroissant/mlcroissant/_src/datasets_test.py
@@ -102,22 +102,27 @@ def load_records_and_test_equality(
f" {record_set_name} --num_records {num_records} --debug --update_output"
f" {filters_command}`"
)
config = _REPOSITORY_FOLDER / "datasets" / version / dataset_name
output_file = config.parent / "output" / f"{record_set_name}.jsonl"
with output_file.open("rb") as f:
lines = f.readlines()
expected_records = [json.loads(line) for line in lines]
if dataset_name.startswith("https://"):
config = dataset_name
expected_records = None
else:
config = _REPOSITORY_FOLDER / "datasets" / version / dataset_name
output_file = config.parent / "output" / f"{record_set_name}.jsonl"
with output_file.open("rb") as f:
lines = f.readlines()
expected_records = [json.loads(line) for line in lines]
if num_records > 0:
assert len(expected_records) == num_records
dataset = datasets.Dataset(config, mapping=mapping)
records = dataset.records(record_set_name, filters=filters)
records = iter(records)
length = 0
for i, record in enumerate(records):
for i in range(num_records):
record = next(records)
if num_records > 0 and i >= num_records:
break
record = record_to_python(record)
assert record == expected_records[i]
length += 1
assert len(expected_records) == length
if expected_records:
assert record == expected_records[i]


def _equal_to_set(expected):
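As a usage sketch of the URL branch this helper now supports, written against mlcroissant's public API (whether this exact dataset is reachable without authentication is not guaranteed, so treat it as illustrative rather than a verified call):

import mlcroissant as mlc

# Point the dataset at a remote Croissant JSON-LD file instead of a
# local config checked into the repository.
ds = mlc.Dataset(
    jsonld="https://huggingface.co/api/datasets/bigcode/the-stack-metadata/croissant"
)

# Records are yielded lazily, so pulling a single filtered record should
# not force the whole record set to be materialized.
records = iter(ds.records("default", filters={"default/split": "train"}))
first = next(records)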
@@ -248,6 +253,15 @@ def test_nonhermetic_loading(version, dataset_name, record_set_name, num_records
["huggingface-c4/metadata.json", "data", 1, {"data/variant": "en"}],
["huggingface-levanti/metadata.json", "levanti_train", 10, None],
["huggingface-open-hermes/metadata.json", "default", 3, None],
# This dataset will timeout if the following feature is broken: mlcroissant
# yields examples by downloading parquet files one by one. mlcroissant should
# not download all parquet files upfront.
[
"https://huggingface.co/api/datasets/bigcode/the-stack-metadata/croissant",
"default",
1,
{"default/split": "train"},
],
],
)
def test_nonhermetic_loading_1_0(dataset_name, record_set_name, num_records, filters):
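The comment in the new test case pins down the behavior being guarded: parquet files must be fetched one at a time as iteration proceeds, never all upfront. A minimal sketch of that pattern in generic pandas code (not mlcroissant's internals):

from collections.abc import Iterator

import pandas as pd

def iter_records(parquet_urls: list[str]) -> Iterator[dict]:
    """Yields rows file by file: if the caller stops early, later files
    are never downloaded, so one record costs at most one fetch."""
    for url in parquet_urls:
        df = pd.read_parquet(url)  # reads only this one file
        yield from df.to_dict(orient="records")

With this shape, `next(iter_records(urls))` touches a single file, which is what keeps the one-record test above within its timeout.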
