diff --git a/python/mlcroissant/mlcroissant/_src/beam.py b/python/mlcroissant/mlcroissant/_src/beam.py
index 0657e004..b96b5c73 100644
--- a/python/mlcroissant/mlcroissant/_src/beam.py
+++ b/python/mlcroissant/mlcroissant/_src/beam.py
@@ -16,33 +16,14 @@
     import apache_beam as beam
 
 
-def _beam_ptransform_fn(fn: Callable[..., Any]) -> Callable[..., Any]:
-    """Lazy version of `@beam.ptransform_fn` in case Beam is not installed."""
-    lazy_decorated_fn = None
-
-    @functools.wraps(fn)
-    def decorated(*args, **kwargs):
-        nonlocal lazy_decorated_fn
-        # Actually decorate the function only the first time it is called
-        if lazy_decorated_fn is None:
-            import apache_beam as beam
-
-            lazy_decorated_fn = beam.ptransform_fn(fn)
-        return lazy_decorated_fn(*args, **kwargs)
-
-    return decorated
-
-
-@_beam_ptransform_fn
 def ReadFromCroissant(
-    pipeline: beam.Pipeline,
     *,
     jsonld: epath.PathLike | Mapping[str, Any],
     record_set: str,
     mapping: Mapping[str, epath.PathLike] | None = None,
     filters: Filters | None = None,
 ):
-    """Returns an Apache Beam reader to generate the dataset using e.g. Spark.
+    """Returns an Apache Beam PCollection to generate the dataset using e.g. Spark.
 
     Example of usage:
 
@@ -65,7 +46,6 @@ def ReadFromCroissant(
     Face datasets, so it raises an error if the dataset is not a Hugging Face dataset.
 
     Args:
-        pipeline: A Beam pipeline (automatically set).
         jsonld: A JSON object or a path to a Croissant file (URL, str or pathlib.Path).
         record_set: The name of the record set to generate.
         mapping: Mapping filename->filepath as a Python dict[str, str] to handle manual
@@ -85,7 +65,4 @@ def ReadFromCroissant(
         A ValueError if the dataset is not streamable.
     """
     dataset = Dataset(jsonld=jsonld, mapping=mapping)
-    return dataset.records(record_set, filters=filters).beam_reader(
-        pipeline,
-        filters=filters,
-    )
+    return dataset.records(record_set, filters=filters).beam_reader(filters=filters)
diff --git a/python/mlcroissant/mlcroissant/_src/datasets.py b/python/mlcroissant/mlcroissant/_src/datasets.py
index 064b662a..fb0ffe2f 100644
--- a/python/mlcroissant/mlcroissant/_src/datasets.py
+++ b/python/mlcroissant/mlcroissant/_src/datasets.py
@@ -176,14 +176,11 @@ def __iter__(self):
             record_set=self.record_set, operations=operations
         )
 
-    def beam_reader(
-        self, pipeline: beam.Pipeline, filters: Mapping[str, Any] | None = None
-    ):
+    def beam_reader(self, filters: Mapping[str, Any] | None = None):
         """See ReadFromCroissant docstring."""
         operations = self._filter_interesting_operations(self.filters)
         execute_downloads(operations)
         return execute_operations_in_beam(
-            pipeline=pipeline,
             record_set=self.record_set,
             operations=operations,
             filters=filters or self.filters,
diff --git a/python/mlcroissant/mlcroissant/_src/operation_graph/execute.py b/python/mlcroissant/mlcroissant/_src/operation_graph/execute.py
index 340dff58..2f863776 100644
--- a/python/mlcroissant/mlcroissant/_src/operation_graph/execute.py
+++ b/python/mlcroissant/mlcroissant/_src/operation_graph/execute.py
@@ -129,7 +129,6 @@ def read_all_files():
 
 
 def execute_operations_in_beam(
-    pipeline: beam.Pipeline,
     record_set: str,
     operations: Operations,
     filters: Mapping[str, Any] | None = None,
@@ -181,19 +180,15 @@ def execute_operations_in_beam(
     for operation in operations_in_memory:
         # If there is no FilterFiles, we return the PCollection without parallelization.
         if operation == target:
-            return (
-                pipeline
-                | beam.Create([(0, *operation.inputs)])
-                | _beam_operation_with_index(operation, sys.maxsize, stage_prefix)
+            return beam.Create([(0, *operation.inputs)]) | _beam_operation_with_index(
+                operation, sys.maxsize, stage_prefix
             )
         else:
             operation(set_output_in_memory=True)
     files = filter_files.output  # even for large datasets, this can be handled in RAM.
     # We first shard by file and assign a shard_index.
-    pipeline = pipeline | f"{stage_prefix} Shard by files with index" >> beam.Create(
-        enumerate(files)
-    )
+    pipeline = beam.Create(enumerate(files))
    num_shards = len(files)
    if not num_shards:
        raise ValueError(
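
For reference, a minimal usage sketch of the changed API: with the `pipeline` argument removed, `ReadFromCroissant(...)` builds the transform chain itself and is composed into a pipeline with `|`, as with any Beam PTransform. This sketch assumes `ReadFromCroissant` remains exported as `mlc.ReadFromCroissant`; the JSON-LD URL and record set name are placeholders, not values from this diff.

    import apache_beam as beam
    import mlcroissant as mlc

    with beam.Pipeline() as pipeline:
        # The caller no longer passes the pipeline into ReadFromCroissant;
        # the returned transform is applied to the pipeline instead.
        records = pipeline | mlc.ReadFromCroissant(
            jsonld="https://example.org/dataset/croissant.json",  # placeholder URL
            record_set="default",  # placeholder record set name
        )
        records | "Print records" >> beam.Map(print)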