Skip to content

Commit

Permalink
Drop all useless operations when we use filtering on a field - so we …
Browse files Browse the repository at this point in the history
…know its value in advance.
  • Loading branch information
marcenacp committed Nov 29, 2024
1 parent d1e81bd commit 4133e02
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
18 changes: 17 additions & 1 deletion python/mlcroissant/mlcroissant/_src/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,23 @@ def _filter_interesting_operations(self, filters: Filters | None) -> Operations:
field, value = _find_data_field_to_filter(filters, interesting_operations)
new_regex = _regex_from_value(field, value)
_propagate_includes(field, interesting_operations, new_regex)
return interesting_operations # pytype: disable=bad-return-type
# The value of `field` is now entirely known so we can remove any operation
# needed to compute it, i.e. all operations involved in a potential join:
join_uuid = field.references.uuid
graph = field.ctx.graph
if join_uuid:
join_node = next(node for node in graph if node.uuid == join_uuid)
unneeded_nodes = [
node
for node in graph
if graph.has_edge(node, join_node) or node == join_node
]
interesting_operations = [
o for o in interesting_operations if o.node not in unneeded_nodes
]
return operations.subgraph(
interesting_operations
) # pytype: disable=bad-return-type


def _find_data_field_to_filter(
Expand Down
9 changes: 9 additions & 0 deletions python/mlcroissant/mlcroissant/_src/datasets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,15 @@ def test_nonhermetic_loading(version, dataset_name, record_set_name, num_records
["huggingface-c4/metadata.json", "data", 1, {"data/variant": "en"}],
["huggingface-levanti/metadata.json", "levanti_train", 10, None],
["huggingface-open-hermes/metadata.json", "default", 3, None],
# This dataset will time out if the following feature is broken: mlcroissant
# yields examples by downloading parquet files one by one. mlcroissant should
# not download all parquet files upfront.
[
"https://huggingface.co/api/datasets/HuggingFaceFW/fineweb/croissant",
"CC-MAIN-2024-10",
1,
{"CC-MAIN-2024-10/split": "train"},
],
],
)
def test_nonhermetic_loading_1_0(dataset_name, record_set_name, num_records, filters):
Expand Down

0 comments on commit 4133e02

Please sign in to comment.