diff --git a/python/mlcroissant/mlcroissant/_src/beam.py b/python/mlcroissant/mlcroissant/_src/beam.py
index 0657e004..91cdd1df 100644
--- a/python/mlcroissant/mlcroissant/_src/beam.py
+++ b/python/mlcroissant/mlcroissant/_src/beam.py
@@ -3,46 +3,22 @@
 from __future__ import annotations
 
 from collections.abc import Mapping
-import functools
-import typing
-from typing import Any, Callable
+from typing import Any
 
 from etils import epath
 
 from mlcroissant._src.datasets import Dataset
 from mlcroissant._src.datasets import Filters
 
-if typing.TYPE_CHECKING:
-    import apache_beam as beam
-
-
-def _beam_ptransform_fn(fn: Callable[..., Any]) -> Callable[..., Any]:
-    """Lazy version of `@beam.ptransform_fn` in case Beam is not installed."""
-    lazy_decorated_fn = None
-
-    @functools.wraps(fn)
-    def decorated(*args, **kwargs):
-        nonlocal lazy_decorated_fn
-        # Actually decorate the function only the first time it is called
-        if lazy_decorated_fn is None:
-            import apache_beam as beam
-
-            lazy_decorated_fn = beam.ptransform_fn(fn)
-        return lazy_decorated_fn(*args, **kwargs)
-
-    return decorated
-
 
-@_beam_ptransform_fn
 def ReadFromCroissant(
-    pipeline: beam.Pipeline,
     *,
     jsonld: epath.PathLike | Mapping[str, Any],
     record_set: str,
     mapping: Mapping[str, epath.PathLike] | None = None,
     filters: Filters | None = None,
 ):
-    """Returns an Apache Beam reader to generate the dataset using e.g. Spark.
+    """Returns an Apache Beam PCollection to generate the dataset using e.g. Spark.
 
     Example of usage:
 
@@ -65,7 +41,6 @@ def ReadFromCroissant(
     Face datasets, so it raises an error if the dataset is not a Hugging Face dataset.
 
     Args:
-        pipeline: A Beam pipeline (automatically set).
         jsonld: A JSON object or a path to a Croissant file (URL, str or pathlib.Path).
         record_set: The name of the record set to generate.
         mapping: Mapping filename->filepath as a Python dict[str, str] to handle manual
@@ -85,7 +60,4 @@ def ReadFromCroissant(
         A ValueError if the dataset is not streamable.
     """
     dataset = Dataset(jsonld=jsonld, mapping=mapping)
-    return dataset.records(record_set, filters=filters).beam_reader(
-        pipeline,
-        filters=filters,
-    )
+    return dataset.records(record_set, filters=filters).beam_reader()
diff --git a/python/mlcroissant/mlcroissant/_src/datasets.py b/python/mlcroissant/mlcroissant/_src/datasets.py
index 064b662a..eba044b3 100644
--- a/python/mlcroissant/mlcroissant/_src/datasets.py
+++ b/python/mlcroissant/mlcroissant/_src/datasets.py
@@ -4,7 +4,6 @@
 
 from collections.abc import Mapping
 import dataclasses
-import typing
 from typing import Any
 
 from absl import logging
@@ -29,9 +28,6 @@
 from mlcroissant._src.structure_graph.nodes.metadata import Metadata
 from mlcroissant._src.structure_graph.nodes.source import FileProperty
 
-if typing.TYPE_CHECKING:
-    import apache_beam as beam
-
 Filters = Mapping[str, Any]
 
 
@@ -176,17 +172,14 @@ def __iter__(self):
             record_set=self.record_set, operations=operations
         )
 
-    def beam_reader(
-        self, pipeline: beam.Pipeline, filters: Mapping[str, Any] | None = None
-    ):
+    def beam_reader(self):
         """See ReadFromCroissant docstring."""
         operations = self._filter_interesting_operations(self.filters)
         execute_downloads(operations)
         return execute_operations_in_beam(
-            pipeline=pipeline,
             record_set=self.record_set,
             operations=operations,
-            filters=filters or self.filters,
+            filters=self.filters,
         )
 
     def _filter_interesting_operations(self, filters: Filters | None) -> Operations:
diff --git a/python/mlcroissant/mlcroissant/_src/operation_graph/execute.py b/python/mlcroissant/mlcroissant/_src/operation_graph/execute.py
index 340dff58..cc6e9b5a 100644
--- a/python/mlcroissant/mlcroissant/_src/operation_graph/execute.py
+++ b/python/mlcroissant/mlcroissant/_src/operation_graph/execute.py
@@ -7,7 +7,6 @@
 import functools
 import json
 import sys
-import typing
 from typing import Any, Generator
 
 from absl import logging
@@ -21,9 +20,6 @@
 from mlcroissant._src.operation_graph.operations.download import Download
 from mlcroissant._src.operation_graph.operations.read import Read
 
-if typing.TYPE_CHECKING:
-    import apache_beam as beam
-
 ElementWithIndex = tuple[int, Any]
 
 
@@ -129,7 +125,6 @@ def read_all_files():
 
 
 def execute_operations_in_beam(
-    pipeline: beam.Pipeline,
     record_set: str,
     operations: Operations,
     filters: Mapping[str, Any] | None = None,
@@ -181,19 +176,15 @@ def execute_operations_in_beam(
     for operation in operations_in_memory:
         # If there is no FilterFiles, we return the PCollection without parallelization.
         if operation == target:
-            return (
-                pipeline
-                | beam.Create([(0, *operation.inputs)])
-                | _beam_operation_with_index(operation, sys.maxsize, stage_prefix)
+            return beam.Create([(0, *operation.inputs)]) | _beam_operation_with_index(
+                operation, sys.maxsize, stage_prefix
             )
         else:
             operation(set_output_in_memory=True)
 
     files = filter_files.output  # even for large datasets, this can be handled in RAM.
     # We first shard by file and assign a shard_index.
-    pipeline = pipeline | f"{stage_prefix} Shard by files with index" >> beam.Create(
-        enumerate(files)
-    )
+    pipeline = beam.Create(enumerate(files))
     num_shards = len(files)
     if not num_shards:
         raise ValueError(