Use filters to define stage names.
ccl-core committed Nov 26, 2024
1 parent fd25b64 commit 196278b
Showing 3 changed files with 8 additions and 8 deletions.
6 changes: 2 additions & 4 deletions python/mlcroissant/mlcroissant/_src/beam.py
@@ -41,7 +41,6 @@ def ReadFromCroissant(
     record_set: str,
     mapping: Mapping[str, epath.PathLike] | None = None,
     filters: Filters | None = None,
-    stage_prefix: str = "",
 ):
     """Returns an Apache Beam reader to generate the dataset using e.g. Spark.
 
@@ -76,8 +75,6 @@ def ReadFromCroissant(
         filters: A dictionary mapping a field ID to the value we want to filter in. For
             example, when writing {'data/split': 'train'}, we want to keep all records
             whose field `data/split` takes the value `train`.
-        stage_prefix: Optional string which will be prepended to stage names in the beam
-            pipeline for better readibility.
 
     Returns:
         A Beam PCollection with all the records where each element contains a tuple with
@@ -89,5 +86,6 @@ def ReadFromCroissant(
     """
     dataset = Dataset(jsonld=jsonld, mapping=mapping)
     return dataset.records(record_set, filters=filters).beam_reader(
-        pipeline, stage_prefix=stage_prefix
+        pipeline,
+        filters=filters,
     )
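
With stage_prefix gone, the reader derives Beam stage names from the record set and the filters on its own. A minimal usage sketch of the new signature (the JSON-LD URL, record-set ID, and filter values below are placeholders for illustration, not part of this commit):

    import apache_beam as beam
    import mlcroissant as mlc

    with beam.Pipeline() as pipeline:
        # Stage names are now derived internally from record_set and filters,
        # so no stage_prefix argument is needed (or accepted) anymore.
        records = mlc.ReadFromCroissant(
            pipeline=pipeline,
            jsonld="https://example.com/croissant.json",  # placeholder URL
            record_set="data",  # placeholder record set ID
            filters={"data/split": "train"},  # placeholder filter
        )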
4 changes: 2 additions & 2 deletions python/mlcroissant/mlcroissant/_src/datasets.py
@@ -176,15 +176,15 @@ def __iter__(self):
             record_set=self.record_set, operations=operations
         )
 
-    def beam_reader(self, pipeline: beam.Pipeline, stage_prefix: str = ""):
+    def beam_reader(self, pipeline: beam.Pipeline, filters: Mapping[str, Any]):
         """See ReadFromCroissant docstring."""
         operations = self._filter_interesting_operations(self.filters)
         execute_downloads(operations)
         return execute_operations_in_beam(
             pipeline=pipeline,
             record_set=self.record_set,
             operations=operations,
-            stage_prefix=stage_prefix,
+            filters=self.filters,
         )
 
     def _filter_interesting_operations(self, filters: Filters | None) -> Operations:
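
A sketch of the equivalent lower-level call chain, mirroring what ReadFromCroissant does above (assuming `dataset` is a Dataset and `pipeline` a beam.Pipeline; per beam.py, the same filters are passed both when building the record set and to beam_reader):

    filters = {"data/split": "train"}  # placeholder filter
    records = dataset.records("data", filters=filters)  # "data" is a placeholder ID
    collection = records.beam_reader(pipeline, filters=filters)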
6 changes: 4 additions & 2 deletions python/mlcroissant/mlcroissant/_src/operation_graph/execute.py
@@ -2,8 +2,10 @@
 
 from __future__ import annotations
 
+from collections.abc import Mapping
 import concurrent.futures
 import functools
+import json
 import sys
 import typing
 from typing import Any, Generator
@@ -130,7 +132,7 @@ def execute_operations_in_beam(
     pipeline: beam.Pipeline,
     record_set: str,
     operations: Operations,
-    stage_prefix: str = "",
+    filters: Mapping[str, Any] | None = None,
 ):
     """See ReadFromCroissant docstring."""
     import apache_beam as beam
@@ -143,6 +145,7 @@
     # We use the FilterFiles operation to parallelize operations. If there's no
     # FilterFile operation, we set it to `target`.
     filter_files = _find_filter_files(operations, target)
+    stage_prefix = f"{record_set} " + json.dumps(filters) if filters else "no filter"
 
     # In memory = all operations that are not between FilterFiles and the target.
     # In Beam = all operations that are between FilterFiles and the target.
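
A subtlety in the added line: Python's conditional expression binds looser than `+`, so the whole concatenation is the `if` branch, and the prefix collapses to just "no filter" (without the record-set name) when filters is falsy. A quick sketch of what it evaluates to (placeholder values):

    import json

    record_set = "data"  # placeholder

    filters = {"data/split": "train"}
    stage_prefix = f"{record_set} " + json.dumps(filters) if filters else "no filter"
    print(stage_prefix)  # data {"data/split": "train"}

    filters = None
    stage_prefix = f"{record_set} " + json.dumps(filters) if filters else "no filter"
    print(stage_prefix)  # no filter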
@@ -187,7 +190,6 @@ def execute_operations_in_beam(
         operation(set_output_in_memory=True)
 
     files = filter_files.output  # even for large datasets, this can be handled in RAM.
-
     # We first shard by file and assign a shard_index.
     pipeline = pipeline | f"{stage_prefix} Shard by files with index" >> beam.Create(
         enumerate(files)
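
Deriving the prefix from record_set and filters matters because Beam requires transform labels to be unique within a pipeline: with a fixed or empty prefix, reading the same dataset twice (say, a train and a test split) would reuse labels like "Shard by files with index" and fail at pipeline construction. A standalone illustration of that collision, independent of mlcroissant:

    import apache_beam as beam

    with beam.Pipeline() as pipeline:
        _ = pipeline | "Shard by files with index" >> beam.Create([1, 2])
        # Reapplying the same label raises a RuntimeError: a transform with
        # label "Shard by files with index" already exists in the pipeline.
        _ = pipeline | "Shard by files with index" >> beam.Create([3, 4])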
