diff --git a/.github/workflows/lint_and_tests.yaml b/.github/workflows/lint_and_tests.yaml index e65772f..5c1c5e8 100644 --- a/.github/workflows/lint_and_tests.yaml +++ b/.github/workflows/lint_and_tests.yaml @@ -13,7 +13,7 @@ jobs: max-parallel: 1 matrix: platform: [ubuntu-latest] - python-version: [3.8] + python-version: [3.9] runs-on: ${{ matrix.platform }} @@ -30,6 +30,7 @@ jobs: # Fairseq doesn't install with pip==22.1 we need to upgrade past it. # Also the version on pypi is from before Oct 2020. # wheel is required by fasttext to be installed correctly with recent pip versions + # fairseq+omegaconf do not play nice when using pip > 24.1 run: | python --version python -m pip install --upgrade 'pip>=22.1.2,<24.1' diff --git a/pyproject.toml b/pyproject.toml index b027a02..3788a3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "flit_core.buildapi" name = "stopes" readme = "README.md" authors = [{name = "Facebook AI Research"}] -requires-python = ">=3.8" +requires-python = ">=3.9" dynamic = ["version", "description"] dependencies = [ @@ -15,7 +15,7 @@ dependencies = [ "submitit>=1.4.5", "tqdm", "posix_ipc", - "pyarrow>=13.0.0" + "pyarrow>=16.1.0" ] # zip_safe = false classifiers=[ @@ -75,7 +75,7 @@ classifiers=[ "torchaudio", "scipy", "pandas", - "pyarrow>=13.0.0", + "pyarrow>=16.1.0", "numba", "transformers", "openai-whisper==20230314", diff --git a/stopes/__init__.py b/stopes/__init__.py index 471399a..f9b510f 100644 --- a/stopes/__init__.py +++ b/stopes/__init__.py @@ -14,4 +14,4 @@ """ -__version__ = "2.1.0" +__version__ = "2.2.0" diff --git a/stopes/core/launcher.py b/stopes/core/launcher.py index 03aa53d..d54c434 100644 --- a/stopes/core/launcher.py +++ b/stopes/core/launcher.py @@ -22,8 +22,16 @@ from stopes.core import utils from stopes.core.cache import Cache, MissingCache, NoCache -from stopes.core.jobs_registry.registry import JobsRegistry -from stopes.core.jobs_registry.submitit_slurm_job import SubmititJob +from stopes.core.jobs_registry.registry import JobsRegistry # type: ignore +from stopes.core.jobs_registry.submitit_slurm_job import SubmititJob # type: ignore + + +@dataclasses.dataclass +class SkipValue: + """A value to skip in the array.""" + + expected_result: tp.Any + if tp.TYPE_CHECKING: from stopes.core import StopesModule @@ -222,6 +230,8 @@ def __init__( log_folder: tp.Union[Path, str] = Path("executor_logs"), cluster: str = "local", partition: tp.Optional[str] = None, + qos: tp.Optional[str] = None, + account: tp.Optional[str] = None, supports_mem_spec: bool = True, # some slurm clusters do not support mem_gb, if you set this to False, the Requirements.mem_gb coming from the module will be ignored disable_tqdm: bool = False, # if you don't want tqdm progress bars max_retries: int = 0, @@ -237,6 +247,8 @@ def __init__( - `log_folder` where to store execution logs for each job (default: `executor_logs`) - `cluster`, a submitit cluster spec. 
`local` to run locally, `slurm` for slurm - `partition`, the slurm partition to use + - `qos`, the slurm QOS (quality-of-service) to use + - `account`, the slurm account to use - `supports_mem_spec`, ignore mem requirements for some cluster - `disable_tqdm`, don't show that fancy progress bar - `max_retries`, how many retries do we want for each job @@ -255,6 +267,8 @@ def __init__( self.log_folder = Path(log_folder) self.cluster = cluster self.partition = partition + self.qos = qos + self.account = account self.supports_mem_spec = supports_mem_spec self.disable_tqdm = disable_tqdm self.progress_bar: tqdm.tqdm = None @@ -265,9 +279,11 @@ def __init__( if throttle: self.throttle = utils.AsyncIPCSemaphore( - name=throttle.shared_name - if throttle.shared_name - else f"/launcher_{getpass.getuser()}_{uuid.uuid4()}", + name=( + throttle.shared_name + if throttle.shared_name + else f"/launcher_{getpass.getuser()}_{uuid.uuid4()}" + ), flags=posix_ipc.O_CREAT, initial_value=throttle.limit, timeout=throttle.timeout, @@ -324,17 +340,24 @@ def _get_executor(self, module: "StopesModule") -> submitit.Executor: name = module.name() module_log_folder = self.log_folder / name module_log_folder.mkdir(parents=True, exist_ok=True) - executor = AutoExecutor(folder=module_log_folder, cluster=self.cluster) + reqs = module.requirements() + executor = AutoExecutor( + folder=module_log_folder, + cluster=self.cluster, + slurm_max_num_timeout=3 if reqs is None else reqs.max_num_timeout, + ) # update launcher params if self.update_parameters: executor.update_parameters(**self.update_parameters) # setup parameters - reqs = module.requirements() + executor.update_parameters( name=module.name(), slurm_partition=self.partition, + slurm_qos=self.qos, + slurm_account=self.account, ) if self.partition and self.cluster == "slurm": executor.update_parameters( @@ -439,7 +462,13 @@ async def _schedule_array( for idx, val in enumerate(value_array): task = Task(module, idx, val, launcher=self) - # first, look up if this iteration has already been cached + # first, look up if this iteration has been skipped + if isinstance(val, SkipValue): + task = task.done_from_cache(val.expected_result) + self.progress_job_end() + tasks.append(task) + continue + # second, look up if this iteration has already been cached try: cached_result = self.cache.get_cache( module, diff --git a/stopes/core/stopes_module.py b/stopes/core/stopes_module.py index e8c84de..c41c0a4 100644 --- a/stopes/core/stopes_module.py +++ b/stopes/core/stopes_module.py @@ -36,6 +36,7 @@ class Requirements: cpus_per_task: int = 5 timeout_min: int = 720 constraint: tp.Optional[str] = None + max_num_timeout: int = 10 class StopesModule(ABC): diff --git a/stopes/modules/bitext/mining/mine_bitext_indexes_utils.py b/stopes/modules/bitext/mining/mine_bitext_indexes_utils.py index 1360cf8..5c29cfc 100644 --- a/stopes/modules/bitext/mining/mine_bitext_indexes_utils.py +++ b/stopes/modules/bitext/mining/mine_bitext_indexes_utils.py @@ -27,7 +27,7 @@ INDICES_FILE_SUFFIX = ".idx" # when requesting more neighbors than elements in the index, FAISS returns -1 -INVALID_INDEX_VALUE = np.uint32(-1) +INVALID_INDEX_VALUE = np.int32(-1) INVALID_INDEX_REPLACEMENT = 0 INVALID_DISTANCES_REPLACEMENT = 2.0 diff --git a/stopes/modules/partitioned_data_mapper.py b/stopes/modules/partitioned_data_mapper.py new file mode 100644 index 0000000..bfba683 --- /dev/null +++ b/stopes/modules/partitioned_data_mapper.py @@ -0,0 +1,499 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import abc
+import copy
+import gc
+import getpass
+import inspect
+import logging
+import math
+import typing as tp
+from abc import ABC, abstractmethod
+from contextlib import nullcontext
+from dataclasses import dataclass
+from datetime import datetime
+from pathlib import Path
+
+import pyarrow as pa
+import submitit
+
+from stopes.core.launcher import SkipValue
+from stopes.core.stopes_module import Requirements, StopesModule
+from stopes.utils.sharding.abstract_shards import (
+    BatchType,
+    InputShardingConfig,
+    OutputDatasetConfig,
+    PartitionedDataMapperState,
+    Shard,
+    batch_length,
+    batch_tail,
+    batch_to_table,
+    concat_batches,
+)
+from stopes.utils.sharding.text_shards import TextShard
+
+
+def get_class_hierarchy(cls) -> tp.List[tp.Any]:
+    classes = inspect.getmro(cls)
+    return [cls for cls in classes if cls not in [object, abc.ABC, BatchMapper]]
+
+
+def source_code(cls) -> str:
+    try:
+        return inspect.getsource(cls)
+    except Exception:
+        return ""
+
+
+def get_class_hierarchy_code(cls) -> tp.Dict[str, str]:
+    classes = get_class_hierarchy(cls)
+    return {repr(cls): source_code(cls) for cls in classes}
+
+
+@dataclass
+class PartitionedDataMapperConfig:
+    input_dataset_config: InputShardingConfig
+    output_dataset_config: OutputDatasetConfig
+
+    def __post_init__(self):
+        # to propagate parquet partitions
+        if getattr(self.output_dataset_config, "keep_same_partitioning", False):
+            assert hasattr(self.input_dataset_config, "partition_columns")
+            assert hasattr(self.output_dataset_config, "partition_columns")
+            self.output_dataset_config.partition_columns = getattr(
+                self.input_dataset_config, "partition_columns"
+            )
+
+
+class PartitionedDataMapper(StopesModule):
+    """
+    The main goal of `PartitionedDataMapper` (and the other classes around it)
+    is to provide an efficient abstraction layer for batch processing in Stopes.
+    In essence, we want to disentangle the batch-based transformation logic from data IO.
+    Thus,
+    - `PartitionedDataMapper` takes care of all the partitioning logic, IO operations, etc., and, when it is subclassed,
+    - the developer only needs to implement the logic that transforms a batch into a batch:
+      `def get_batch_mapper(self) -> tp.Callable[[tp.Optional[BatchType]], tp.Optional[BatchType]]:`
+    - Here, we need to return a callable that can be applied several times in a mini-batching loop without being reinitialized.
+    - One can subclass `class BatchMapper(ABC):` to structure such callable objects, typically for the model-inference case.
+
+    This pattern is similar to [`Dask.DataFrame.map_partitions`](https://docs.dask.org/en/stable/generated/dask.dataframe.DataFrame.map_partitions.html),
+    [`pyspark.RDD.mapPartitions`](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.RDD.mapPartitions.html) or
+    [`ray.data.Dataset.map_batches`](https://docs.ray.io/en/latest/data/api/doc/ray.data.Dataset.map_batches.html).
+
+    * In particular, this decoupling allows Stopes batch-transformer code to be reused interchangeably with any such higher-level framework.
+    * Chaining several transformations over one batch inside a single Stopes module should also be easier with this approach.
+ + """ + + state: tp.Optional[PartitionedDataMapperState] + iteration_index: int + written_batches_index: int + + def __init__( + self, + config: PartitionedDataMapperConfig, + state: tp.Optional[PartitionedDataMapperState] = None, + shards_to_skip: tp.Optional[tp.Dict[int, tp.List[Path]]] = None, + ) -> None: + super().__init__(config, PartitionedDataMapperConfig) + + self.input_dataset_config = config.input_dataset_config + self.output_dataset_config = config.output_dataset_config + self.logger = logging.getLogger(self.__class__.__name__) + self.state = state + self.shards_to_skip = shards_to_skip + + if ( + self.output_dataset_config.expected_schema is None + and self.output_dataset_config.validate_schema + ): + self.output_dataset_config.expected_schema = self.guessed_expected_schema() + + if ( + self.output_dataset_config.expected_schema is None + or not self.output_dataset_config.validate_schema + ): + self.logger.warning("Output schema will NOT be validated") + + def array(self) -> tp.List[Shard]: + if self.state is not None: + return [self.state.iteration_value] + + full_array = self.input_dataset_config.make_shards() + if self.shards_to_skip is not None: + self.logger.warn(f"Adding shards to skip: {len(self.shards_to_skip)}") + for idx, shards_to_skip in self.shards_to_skip.items(): + full_array[idx] = SkipValue(shards_to_skip) # type: ignore + return full_array + + @abstractmethod + def get_batch_mapper( + self, + ) -> tp.Callable[[tp.Optional[BatchType]], tp.Optional[BatchType]]: + # for stateful mapping one can follow the this pattern + # `return CallableBatchMapperClass(self.my_batch_mapper_config)` + # with `CallableBatchMapperClass.__call__(tp.Optional[BatchType]]) -> BatchType` + return lambda batch: batch + + def get_metadata(self, *args, **kwargs) -> tp.Dict[str, tp.Any]: + return { + **self.get_common_metadata(*args, **kwargs), + **self.get_custom_metadata(*args, **kwargs), + } + + def get_common_metadata(self, *args, **kwargs) -> tp.Dict[str, tp.Any]: + meta = {} + batch_mapper = kwargs.get("batch_mapper", None) + + meta["config"] = self.config + if batch_mapper: + meta["batch_mapper_class"] = get_class_hierarchy_code( + batch_mapper.__class__ + ) + try: + meta["batch_mapper_code"] = inspect.getsource(batch_mapper) + except Exception: + meta["batch_mapper_code"] = "" + + meta["username"] = getpass.getuser() + meta["save_time"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + # TODO: add pip/git/conda info + return meta + + @abstractmethod + def get_custom_metadata(self, *args, **kwargs) -> tp.Dict[str, tp.Any]: + return {} + + def guessed_expected_schema( + self, num_rows=10, nb_tries=3 + ) -> tp.Optional[pa.Schema]: + + arrays = self.array() + if len(arrays) == 0: + self.logger.warning("Empty ARRAYs detected") + return None + batch_mapper_fn = self.get_batch_mapper() + for i in range(nb_tries): + batch = next( + arrays[i].to_batches( + batch_size=min( + num_rows, self.input_dataset_config.batch_size or num_rows + ), + columns=self.input_dataset_config.columns, + batch_format=self.input_dataset_config.batch_format, + ) + ) + if batch is not None: + output_batch = batch_mapper_fn(batch) + return batch_to_table(output_batch).schema + + self.logger.warning("Resulted batch is empty") + return None + + def run( + self, iteration_value: tp.Optional[tp.Any] = None, iteration_index: int = 0 + ) -> tp.List[Path]: + + if self.state is None: + # try to get state from checkpoint + self.state = self.output_dataset_config.reload_state(iteration_value) # type: ignore + + if 
self.state is None: + self.state = PartitionedDataMapperState( + iteration_index=iteration_index, + iteration_value=iteration_value, # type: ignore + written_batches_index=-1, + intermediate_index=0, + input_rows_written=0, + ) + iteration_value = self.state.iteration_value + iteration_index = self.state.iteration_index + + if iteration_value is None: + self.logger.warning( + f"Input `None` iteration_value, skipping iteration {iteration_index}" + ) + return [] + + if not isinstance(iteration_value, Shard): + raise ValueError("Partitioned dataset should be defined") + + shard: Shard = iteration_value + # we need to initialize the batch_mapper_fn each to to avoid sharing it state between workers + batch_mapper_fn = self.get_batch_mapper() + metadata = self.get_metadata(**{"batch_mapper": batch_mapper_fn}) + output_batch_size = self.output_dataset_config.batch_size + # transforming batch split by mini-batches (typical for GPU inference) + mini_batch_results: tp.List[BatchType] = [] + + nb_current_samples = 0 + output_written_paths = [] + + # FIXME: text_shard has been entered implicitly in to_batches(), change + # this to have a consistent code among all types of shards + if isinstance(shard, TextShard): + _context = nullcontext() + else: + _context = shard # type: ignore + + current_input_rows_processed = 0 + + if hasattr(shard, "skip_n_rows") and shard.skip_n_rows and output_batch_size: + # we need to advance intermediate_index to make sure to not overwrite existing files + self.state.intermediate_index = ( + math.ceil(shard.skip_n_rows / output_batch_size) + 1 + ) + + with _context: + for batch_idx, batch in enumerate( + shard.to_batches( + batch_size=self.input_dataset_config.batch_size, + columns=self.input_dataset_config.columns, + batch_format=self.input_dataset_config.batch_format, + ) + ): + if batch_idx <= self.state.written_batches_index: + continue + + input_batch_len = batch_length(batch) + + left_to_skip = 0 + # skip from the state + if ( + current_input_rows_processed + input_batch_len + <= self.state.input_rows_written + ): + current_input_rows_processed += input_batch_len + continue + else: + left_to_skip = ( + self.state.input_rows_written - current_input_rows_processed + ) + + # skip from shard config + if hasattr(shard, "skip_n_rows") and shard.skip_n_rows > 0: + if ( + current_input_rows_processed + input_batch_len + <= shard.skip_n_rows + ): + current_input_rows_processed += input_batch_len + continue + else: + # shard config overrides state + left_to_skip = shard.skip_n_rows - current_input_rows_processed + + if left_to_skip > 0: + batch = batch_tail(batch, input_batch_len - left_to_skip) + + transformed_batch = batch_mapper_fn(batch) + current_input_rows_processed += input_batch_len + if transformed_batch is not None: + nb_current_samples += batch_length(transformed_batch) + mini_batch_results.append(transformed_batch) + + if output_batch_size and nb_current_samples >= output_batch_size: + batch_to_write = concat_batches(mini_batch_results) + new_state = copy.copy(self.state) + new_state.intermediate_index += 1 + new_state.written_batches_index = batch_idx + new_state.input_rows_written = current_input_rows_processed + + files_path = self.output_dataset_config.write_batch( + batch_to_write, + (self.state.intermediate_index, iteration_index), + metadata=metadata, + state_checkpoint=new_state, + ) + self.logger.info( + f"Following files has been written: \n {files_path}" + ) + output_written_paths.extend(files_path) + nb_current_samples = 0 + # only update state on 
successful write + self.state = new_state + # TODO deal with writting state to checkpoint here + mini_batch_results = [] + gc.collect() + + if len(mini_batch_results) > 0: + batch_to_write = concat_batches(mini_batch_results) + new_state = copy.copy(self.state) + new_state.intermediate_index += 1 + new_state.written_batches_index = batch_idx + new_state.input_rows_written = current_input_rows_processed + files_path = self.output_dataset_config.write_batch( + batch_to_write, + (self.state.intermediate_index, iteration_index), + metadata=metadata, + state_checkpoint=new_state, + ) + self.logger.info(f"Following files has been written: \n {files_path}") + output_written_paths.extend(files_path) + self.state = None + + return output_written_paths + + def name(self): + name = ( + self.get_custom_metadata().get("name") + or str(self.input_dataset_config.input_file)[-250:] + ) + return name + + def checkpoint( + self, *args, **kwargs + ) -> tp.Optional[submitit.helpers.DelayedSubmission]: + if self.state is not None: + return submitit.helpers.DelayedSubmission(self) + return None + + +class BatchMapper(ABC): + """ + Abstract class that could be used to structure a Statefull transformation. + It takes typically a config for the init, loads some models in init. + Then there a `__call__` method to implement that transform Batch to Batch (potentially small) + + Example: + + ... import librosa + >>> import whisper + >>> import torch + >>> import pandas as pd + + + >>> class LIDPredictor(BatchMapper): + ... def __init__(self, model_config) -> None: + ... self.model_config = model_config + ... self.model = whisper.load_model(model_config.get("name", "large-v2")) + + ... def row_mapper(self, wav: np.ndarray) -> tp.Dict[str, float]: + ... audio = whisper.pad_or_trim(wav, length=480000) + ... mel = whisper.log_mel_spectrogram(audio).to(self.model.device) + ... _, probs = self.model.detect_language(mel) + ... return probs + + ... def __call__(self, batch: BatchType) -> pd.DataFrame: + if not isinstance(batch, pd.DataFrame): + ... batch = batch.to_pandas() + ... inp_col = self.model_config.get("input_column", "wavform") + ... out_col = self.model_config.get("output_column", "lid_proba") + + ... with torch.inference_mode(): + ... batch[f"out_col{output_suffix}"] = batch[inp_col].apply(self.row_mapper) + + ... return batch + + """ + + def __init__(self, config, *args, **kwargs): + self.config = config + # self.model = load_model(config.path) + + @abstractmethod + def __call__(self, batch: tp.Optional[BatchType]) -> tp.Optional[BatchType]: + if batch is None: + return None + # more complex case we want to do + # return self.model(pandas_to_torch(batch)) + return batch + + def clear_memory(self) -> None: + pass + # import gc + # if self.model is not None: + # try: + # self.model.cpu() + # self.model = None + # except: + # pass + # gc.collect() + # torch.cuda.empty_cache() + + +@dataclass +class IOConfigWithBatchMapper(PartitionedDataMapperConfig): + mapper_config: tp.Any = None + + +def stopes_data_mapper( + requirements: Requirements, + metadata: tp.Optional[tp.Dict[str, tp.Any]] = None, + shards_to_skip: tp.Optional[tp.Dict[int, tp.List[Path]]] = None, + state: tp.Optional[PartitionedDataMapperState] = None, +): + """ + A decorator function that can be used to wrap a BatchMapper class into PartitionedDataMapper + + Args: + - requirements (Requirements): + - metadata (tp.Optional[tp.Dict[str, tp.Any]], optional):. Metadata to attach to resulting dataset. Defaults to None. 
+ + Example: + + @stopes_data_mapper(requirements=Requirements()) + class LIDPredictor(BatchMapper): + ... + def __init__(self, ...): + ... + + def __call__(self, batch): + ... + return transformed_batch + + lid_stopes_module = LIDPredictor(input_config, output_config, mapper_config={"model_name": "large_v2"}) + + launcher = Launcher( + cache=None, + config_dump_dir=..., + log_folder=..., + cluster="slurm", + ) + + results = await launcher.schedule(lid_stopes_module) + + """ + + def decorator_mapper(mapper_cls: BatchMapper): + class WrappedMapper(PartitionedDataMapper): + def get_batch_mapper( + self, + ): + return mapper_cls( + self.config.mapper_config, + ) + + def requirements(self): + return requirements + + def get_custom_metadata(self, *args, **kwargs) -> tp.Dict[str, tp.Any]: + if metadata is None: + return {} + return metadata or {} + + def config_assembler( + input_dataset_config: InputShardingConfig, + output_dataset_config: OutputDatasetConfig, + mapper_config: tp.Any, # should be a BatchedMapper dataclass config + ): + io_config_with_mapper = IOConfigWithBatchMapper( + input_dataset_config=input_dataset_config, + output_dataset_config=output_dataset_config, + mapper_config=mapper_config, + ) + + return WrappedMapper( + io_config_with_mapper, + state=state, + shards_to_skip=shards_to_skip, + ) + + return config_assembler + + return decorator_mapper diff --git a/stopes/modules/preprocess/sonar_text_embedding.py b/stopes/modules/preprocess/sonar_text_embedding.py new file mode 100644 index 0000000..8fd83c8 --- /dev/null +++ b/stopes/modules/preprocess/sonar_text_embedding.py @@ -0,0 +1,260 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import gc +import typing as tp +from dataclasses import dataclass +from functools import partial + +import numpy as np +import pandas as pd +import pyarrow as pa +import torch +from fairseq2.assets.error import AssetError +from retrying import retry +from sonar.inference_pipelines.text import ( + EmbeddingToTextModelPipeline, + TextToEmbeddingModelPipeline, +) + +from stopes.core.stopes_module import Requirements +from stopes.modules.partitioned_data_mapper import ( + BatchMapper, + PartitionedDataMapper, + PartitionedDataMapperConfig, +) +from stopes.utils.arrow_utils import ( + apply_on_nested_array, + apply_over_groups, + numpy_to_fixed_size_pyarrow_array, + pyarrow_fixed_size_array_to_numpy, +) +from stopes.utils.sharding.abstract_shards import BatchFormat, batch_to_table +from stopes.utils.sharding.parquet_shards import ParquetOutputConfig + +fairse2_asset_loading_retry = retry( + retry_on_exception=lambda exception: isinstance(exception, (AssetError, IOError)), + stop_max_attempt_number=20, + wait_random_min=1000, + wait_random_max=30_000, +) + + +@dataclass +class LangColumnConfig: + column: str + lang_value: tp.Optional[str] = None + lang_column: tp.Optional[str] = None + suffix: str = "_sonar_emb" + + """ + 1. For multi-lingual setup, provide `lang_column` a column of input dataset containing lang values of `column` + 2. 
For a mono-lingual setup, provide the language directly with `lang_value` (e.g. "eng_Latn")
+    """
+
+    def __post_init__(self):
+        if (self.lang_value is None) == (self.lang_column is None):
+            raise ValueError(
+                "Exactly one of `lang_value` and `lang_column` should be provided"
+            )
+
+        assert len(self.suffix) > 0
+
+
+@dataclass
+class SonarTextEmbedderConfig:
+    column_config: tp.List[LangColumnConfig]
+    model_name: str = "text_sonar_basic_encoder"
+    tokenizer_name: tp.Optional[str] = None
+    device: str = "cuda"
+    batch_size: int = 10
+    dtype: tp.Optional[str] = None  # "float32"
+    """
+    This config allows handling multiple columns, each of which can be multilingual.
+    It also supports columns with nested text values (a list of sentences per row),
+    in which case it returns a nested embeddings column.
+    """
+
+    def __post_init__(self):
+        self.tokenizer_name = self.tokenizer_name or self.model_name
+
+
+class _MonoLangMapperInterface(BatchMapper):
+    @torch.inference_mode()
+    def _apply_on_simple_column(
+        self,
+        col: tp.Union[pa.Array, pa.ChunkedArray],
+        lang_value: str,
+    ) -> tp.Union[pa.Array, pa.ChunkedArray]:
+        ...
+
+    def _apply_on_unique_lang_table(
+        self, table: pa.Table, config: LangColumnConfig
+    ) -> pa.Table:
+        if config.lang_column:
+            assert (
+                len(table[config.lang_column].unique()) == 1
+            ), "this method should be called only for unique lang values"
+            lang_value = table[config.lang_column][0].as_py()
+        else:
+            lang_value = config.lang_value
+
+        try:
+            col = table[config.column]
+        except KeyError:
+            # `table.flatten()` allows accessing fields from a struct directly
+            # with the following name: `{column_name}.{struct_field_name}`
+            col = table.flatten()[config.column]
+
+        new_column = apply_on_nested_array(
+            partial(self._apply_on_simple_column, lang_value=lang_value),
+            col,
+        )
+        new_name = f"{config.column}{config.suffix}"
+        return table.append_column(new_name, new_column)
+
+    def __call__(
+        self, table: tp.Optional[tp.Union[pa.Table, pd.DataFrame]]
+    ) -> tp.Optional[pa.Table]:
+        if table is None:
+            return None
+
+        table = batch_to_table(table)
+
+        for current_config in self.config.column_config:
+            table = apply_over_groups(
+                table,
+                [
+                    current_config.lang_column
+                ],  # note: if `current_config.lang_column` is None, the function is applied to the full table
+                partial(self._apply_on_unique_lang_table, config=current_config),
+            )
+
+        return table
+
+
+class SonarTextBatchEmbedder(_MonoLangMapperInterface):
+    def __init__(self, config: SonarTextEmbedderConfig) -> None:
+        super().__init__(config)
+        self.dtype = np.dtype(self.config.dtype) if self.config.dtype else None
+        self.pipeline = fairse2_asset_loading_retry(TextToEmbeddingModelPipeline)(
+            self.config.model_name,
+            self.config.tokenizer_name,
+            device=torch.device(self.config.device),
+        )
+
+    @torch.inference_mode()
+    def _apply_on_simple_column(
+        self,
+        col: tp.Union[pa.Array, pa.ChunkedArray],
+        lang_value: str,
+    ) -> pa.FixedSizeListArray:
+
+        assert pa.types.is_string(col.type) or pa.types.is_binary(
+            col.type
+        ), f"unsupported dtype: {col.type}"
+        inp = col.to_pandas()
+        order = np.argsort([len(x) for x in inp])
+        ordered_inp = inp.iloc[order].to_list()
+        emb = (
+            self.pipeline.predict(
+                input=ordered_inp,
+                source_lang=lang_value,
+                batch_size=self.config.batch_size,
+            )
+            .cpu()
+            .numpy()
+        )
+        torch.cuda.empty_cache()
+        gc.collect()
+        if self.dtype:
+            emb = emb.astype(self.dtype)
+        inv_order = np.argsort(order)
+        emb = emb[inv_order]
+        return numpy_to_fixed_size_pyarrow_array(emb)
+
+
+@dataclass
+class SonarEmbeddingDecodingConfig(SonarTextEmbedderConfig): + model_name: str = "text_sonar_basic_decoder" + generator_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None + """ + for generator kwargs, please refer to `BeamSearchSeq2SeqGenerator` in fairseq2 : + https://github.com/facebookresearch/fairseq2/blob/main/src/fairseq2/generation/beam_search.py + + """ + + # dont forget to rename column suffix + # suffix: str = "_sonar_emb" + def __post_init__(self): + super().__post_init__() + for col_config in self.column_config: + if col_config.suffix == "_sonar_emb": + print( + "Using default ENCODER suffix '_sonar_emb'," + f" condsider replace it for column = {col_config.column}" + ) + + +class SonarEmbeddingToTextMapper(_MonoLangMapperInterface): + def __init__(self, config: SonarEmbeddingDecodingConfig) -> None: + super().__init__(config) + self.pipeline = fairse2_asset_loading_retry(EmbeddingToTextModelPipeline)( + self.config.model_name, + self.config.tokenizer_name, + device=torch.device(self.config.device), + ) + + @torch.inference_mode() + def _apply_on_simple_column( + self, + col: tp.Union[pa.Array, pa.ChunkedArray], + lang_value: str, + ) -> tp.Union[pa.Array, pa.ChunkedArray]: + np_array = pyarrow_fixed_size_array_to_numpy(col) + + text = self.pipeline.predict( + inputs=torch.from_numpy(np_array), + target_lang=lang_value, + batch_size=self.config.batch_size, + **(self.config.generator_kwargs or {}), + ) + return pa.array(text, type=pa.string()) + + +@dataclass +class SonarTextEmbedderStopesConfig(PartitionedDataMapperConfig): + sonar_config: SonarTextEmbedderConfig + + def __post_init__(self): + super().__post_init__() + self.input_dataset_config.batch_format = BatchFormat.ARROW + assert isinstance( + self.output_dataset_config, ParquetOutputConfig + ), "Embedding can be serialized only in Parquet" + + +class SonarTextEmbedderStopes(PartitionedDataMapper): + def __init__(self, config: SonarTextEmbedderStopesConfig) -> None: + super().__init__(config) + + def get_custom_metadata(self, *args, **kwargs) -> tp.Dict[str, tp.Any]: + return {} + + def requirements(self): + return Requirements( + nodes=1, + mem_gb=30, + tasks_per_node=1, + gpus_per_node=int(bool(self.config.sonar_config.device.startswith("cuda"))), + cpus_per_task=( + 4 if self.config.sonar_config.device.startswith("cuda") else 20 + ), + ) + + def get_batch_mapper(self): + return SonarTextBatchEmbedder(self.config.sonar_config) diff --git a/stopes/modules/speech/audio_load_utils.py b/stopes/modules/speech/audio_load_utils.py index 73210da..4a621ab 100644 --- a/stopes/modules/speech/audio_load_utils.py +++ b/stopes/modules/speech/audio_load_utils.py @@ -12,12 +12,21 @@ import numpy as np import torch -from torch import multiprocessing from stopes.modules.speech import utils as speech_utils log = logging.getLogger(__name__) -multiprocessing.set_start_method("spawn", force=True) + + +@contextlib.contextmanager +def spawn_mp(num_processses: int): + from torch import multiprocessing + + start_method = multiprocessing.get_start_method() + multiprocessing.set_start_method("spawn", force=True) + yield multiprocessing.Pool(num_processses) + + multiprocessing.set_start_method(start_method) def parallel_audio_read( @@ -28,8 +37,8 @@ def parallel_audio_read( num_process: tp.Optional[int] = 4, chunksize: int = 16, sampling_factor: int = 16, - read_audio_func: tp.Callable = speech_utils.read_audio, collapse_channels: bool = False, + **kwargs, ) -> tp.Iterator[Tuple[str, np.ndarray]]: """ Load audio from a given manifest file, using 
several processes. @@ -40,7 +49,8 @@ def parallel_audio_read( """ with contextlib.ExitStack() as stack: if num_process is None or num_process > 1: - pool = stack.enter_context(multiprocessing.Pool(num_process)) + num_process = num_process or 4 # default number of processes + pool = stack.enter_context(spawn_mp(num_process)) # chunksize=16, to send segments from the same file to the same worker. # Note this optimization only work if the input file is sorted per segment, # which in the case of mining only work for the source segments. @@ -58,8 +68,8 @@ def parallel_audio_read( # pytorch modifies the multiprocessing behaviour to optimize serializing Tensors. # This causes non-deterministic "No space left on device" errors when we return a Tensor. # In order to avoid that, we are using numpy arrays as return values. - read_audio_func=read_audio_func, collapse_channels=collapse_channels, + **kwargs, ) yield from pool_imap(_load, lines) @@ -71,8 +81,8 @@ def load_audio( line: str, sampling_factor: int = 16, as_numpy: bool = False, - read_audio_func: tp.Callable = speech_utils.read_audio, collapse_channels: bool = False, + **kwargs, ) -> tp.Tuple[str, tp.Union[torch.Tensor, np.ndarray]]: """ Load audio from a TSV-line where column at `column_offset` contains audio info @@ -91,7 +101,10 @@ def load_audio( if isinstance(audio_meta, speech_utils.Audio): # mp3 files need to be fully read before we can extract a segment. # But segments are sorted so we should hit several time the same file. - wav = read_audio_func(audio_meta.path, audio_meta.sampling_factor * 1000) + wav = speech_utils.read_audio( + audio_meta.path, + audio_meta.sampling_factor * 1000, + ) if len(wav.shape) > 1: wav = wav[:, audio_meta.start : audio_meta.end] else: @@ -99,7 +112,12 @@ def load_audio( elif isinstance(audio_meta, speech_utils.AudioBytes): wav = audio_meta.load() elif isinstance(audio_meta, speech_utils.Text): - wav = read_audio_func(audio_meta.content, sampling_factor * 1000) + wav = speech_utils.read_audio( + audio_meta.content, + sampling_factor * 1000, + kwargs.get("start_frame", None), + kwargs.get("end_frame", None), + ) if gpu and torch.cuda.is_available(): wav = wav.cuda() if fp16: diff --git a/stopes/modules/speech/audio_zip.py b/stopes/modules/speech/audio_zip.py index 4016849..ec70b4e 100644 --- a/stopes/modules/speech/audio_zip.py +++ b/stopes/modules/speech/audio_zip.py @@ -12,31 +12,38 @@ import zipfile from pathlib import Path +import hydra import torch import torchaudio from tqdm import tqdm from stopes.core import Requirements, StopesModule, utils -from stopes.modules.speech import speech_units +from stopes.core.utils import open as stopes_open +from stopes.modules.speech import audio_load_utils from stopes.modules.speech import utils as speech_utils +from stopes.utils.config_utils import parse_hydra_list from stopes.utils.data_utils import DictWithOverrides -from stopes.utils.shards import ( - Shard, - make_one_shard, - make_shards, +from stopes.utils.sharding.text_shards import ( + TextShard, + make_one_text_shard, + make_text_file_shards, parse_header, resolve_output, ) -log = logging.getLogger("stopes.speech.audio_zip") - @dataclasses.dataclass class AudioZipConfig: - tsv_file: Path - # column index (0,1) of column name ("src_audio", "tgt_audio",..) - column: tp.Union[int, str] + input_file: Path + + # column index (0,1) or column name ("src_audio", "tgt_audio",..) + # Accepted values: + # - column index: 0, 1, ... + # - column headers: "src_audio", "tgt_audio",.. 
+ # - List of column indicdes: '[0,1]', '[2,3]', ... + # - List of column headers: '[es_audio,en_audio]', '[src_audio,tgt_audio]', ... + column: tp.Any output_dir: Path output_prefix: tp.Optional[str] = None sample_rate: int = 16_000 @@ -48,9 +55,17 @@ class AudioZipConfig: # nshards = no. of shards to split the inputs nshards: int = 1 + # Whether to perform internal sorting to speed up audio reading + # (But will change the order of audios in the ip files) + sorted: bool = True + # Whether to provide zip for duplicate segments no_duplicate: bool = True + # Custom reading logic in audio zip, applied to each shard at runtime + # Default is None, meaning using the builtin reader of the shard + reader_wrapper: tp.Optional[str] = None + class AudioZipModule(StopesModule): """ @@ -65,7 +80,7 @@ class AudioZipModule(StopesModule): Example command to run audiozip from a manifest file on content of the 3rd column = 2 (0 index) python -m stopes.modules +audio_zip=base \ - audio_zip.tsv_file=myfile.tsv \ + audio_zip.input_file=myfile.tsv \ audio_zip.output_zip=myoutput.zip \ audio_zip.column=2 \ launcher.cluster=debug @@ -75,8 +90,8 @@ class AudioZipModule(StopesModule): def __init__(self, config: AudioZipConfig, **kwargs): super().__init__(config, AudioZipConfig) - if not self.config.tsv_file.exists(): - raise ValueError(f"Input tsv_file not found: {self.config.tsv_file}") + if not self.config.input_file.exists(): + raise ValueError(f"Input input_file not found: {self.config.input_file}") self.output_dir = Path(self.config.output_dir).resolve() self.output_dir.mkdir(exist_ok=True, parents=True) @@ -84,36 +99,98 @@ def __init__(self, config: AudioZipConfig, **kwargs): self.config.output_prefix if self.config.output_prefix is not None and len(self.config.output_prefix) > 0 - else self.config.tsv_file.stem.replace(".tsv", "") + else self.config.input_file.stem.replace(".tsv", "") .replace(".gz", "") .strip() ) - if self.config.nshards > 1: - (self.output_dir / "workdir").mkdir(exist_ok=True) + + self.columns = parse_hydra_list(self.config.column) + assert ( + self.columns + ), f"Expect config.column to be a list or string, get {self.config.column}" + self.multi_column_mode = len(self.columns) > 1 + self.header = not str(self.columns[0]).isdecimal() + + self._prepare_tmp_dir() self.output_zip = self.output_dir / f"{self.output_prefix}_audio.zip" self.manifest_tsv = self.output_dir / f"{self.output_prefix}_zipped.tsv.gz" + self.logger = logging.getLogger("stopes.speech.audio_zip") self.kwargs = kwargs - self.header = ( - isinstance(self.config.column, str) and not self.config.column.isdecimal() + + def _prepare_tmp_dir(self): + """Prepare a temporary directory if needed""" + if self.config.nshards > 1 and ( + self.config.input_file.suffix.endswith(".gz") + or self.config.input_file.suffix.endswith(".zip") + ): + (self.output_dir / "workdir").mkdir(exist_ok=True) + + # In multi-column mode, we need a working dir to store intermediate file + if self.multi_column_mode: + (self.output_dir / "workdir").mkdir(exist_ok=True) + + def _combine_columns(self) -> Path: + """ + combine multiple audio columns into ones. 
Return the intermediate file with + one column containing all audio paths + """ + assert ( + self.columns + ), f"Expect config.column to be a list or string, get {self.config.column}" + + combined_file = ( + self.output_dir + / "workdir" + / (Path(self.config.input_file).name + ".columns_merged") ) + with ( + stopes_open(self.config.input_file) as reader, + stopes_open(combined_file, "a+") as writer, + ): + if self.header and isinstance(self.header, tp.Iterable): + cols = next(reader).rstrip("\n").split("\t") + col_offsets = [cols.index(c) for c in self.columns] + else: + col_offsets = [int(c) for c in self.columns] + for line in reader: + columns = line.rstrip("\t").split("\t") + for col in col_offsets: + writer.write(f"{ columns[col]}\n") + return combined_file def requirements(self) -> Requirements: return Requirements(timeout_min=1440) - def array(self) -> tp.List[Shard]: + def reader_wrapper(self) -> tp.Callable[..., tp.ContextManager]: + """Provide a reader wrapper around a Shard, optionally with extra logic""" + if self.config.reader_wrapper is None: + return lambda x, *unused, **kwargs: x + else: + return hydra.utils.get_method(self.config.reader_wrapper) + + def array(self) -> tp.List[TextShard]: + if self.multi_column_mode: + input_file: Path = self._combine_columns() + col = 0 + else: + input_file = self.config.input_file + col = self.columns[0] # type: ignore return list( - make_shards( - self.config.tsv_file, + make_text_file_shards( + input_file, nshards=self.config.nshards, - algo="sort", + algo="sort" if self.config.sorted else "chunk", cache_dir=self.output_dir / "workdir", header=self.header, sep="\t", - col=self.config.column, + col=col, no_duplicate=self.config.no_duplicate, ) ) + def is_zip_complete(self, shard: TextShard, manifest: Path): + """ """ + def run( self, iteration_value: tp.Any = None, iteration_index: int = 0 ) -> tp.Tuple[Path, Path]: @@ -123,30 +200,68 @@ def run( # module from command line: python -m stopes.modules.speech.audio_zip ... 
# In this case we create a dummy Shard object with index = None if shard is None: - cols = parse_header(self.config.tsv_file, "\t") if self.header else None - shard = Shard(self.config.tsv_file, cols=cols, sep="\t") + cols = parse_header(self.config.input_file, self.header, "\t") + shard = TextShard( + input_file=self.config.input_file, columns=cols, sep="\t", filter=None + ) + + assert isinstance( + shard, TextShard + ), "Each audio zip task expects input to be sharded, or convert to a Shard" + output_zip = resolve_output(shard, Path(self.output_zip), suffix=".zip") output_manifest = resolve_output(shard, Path(self.manifest_tsv), suffix=".tsv") + error_manifest = resolve_output(shard, Path(self.manifest_tsv), suffix=".err") assert output_zip and output_manifest - column_offset = shard.resolve_column_index(self.config.column) + # IF there are more than one audio columns, we are actually handling the + # combined tsv file here which is a one-column manifest + if self.multi_column_mode: + column_offset = 0 + else: + assert isinstance( + self.config.column, (int, str) + ), f"invalid column setting: {self.config.column} (Expect int or str)" + column_offset = shard.resolve_column_index(self.config.column) sample_rate = self.config.sample_rate audio_format = self.config.audio_format - with shard as progress: - zip_comment = ( - f"Audio segments extracted from {shard.input_file} [{shard.index}]" - ) - with AudioZipWriter( - output_zip=output_zip, - manifest_path=output_manifest, - audio_format=audio_format, - add_header=False, - zip_comment=zip_comment, - ) as writer: - lines = iter(progress) + # Before running, check if the manifest exists and skip if the content is full + # (Useful in rerunning the job) + skipped = False + if validate( + output_manifest, + output_zip, + self.config.audio_format, + self.config.output_validation_token, + ): + self.logger.info(f"Skip shard #{iteration_index}") + skipped = True + + if not skipped: + with contextlib.ExitStack() as stack: + writer = stack.enter_context( + AudioZipWriter( + output_zip=output_zip, + manifest_path=output_manifest, + audio_format=audio_format, + add_header=False, + zip_comment=f"Audio segments extracted from {shard.input_file} [{shard.index}]", + ) + ) + reader = stack.enter_context( # type: ignore + self.reader_wrapper()( # type: ignore + shard, + column_offset=column_offset, + logger=self.logger, + **self.kwargs, + ) + ) + err = stack.enter_context(stopes_open(error_manifest, "wt")) # type: ignore + lines = iter(reader) + for line, audio in tqdm( - speech_units.parallel_audio_read( + audio_load_utils.parallel_audio_read( lines, column_offset, sampling_factor=int(sample_rate / 1000), @@ -154,25 +269,38 @@ def run( ), unit="segment", ): - columns = line.rstrip("\n").split("\t") - audio = torch.tensor(audio, dtype=torch.float) - metadata_column_values = [] - if self.config.store_num_frames: - metadata_column_values.append(str(audio.size(-1))) - if self.config.store_input_line: - metadata_column_values.append(line.rstrip("\n")) - writer.append( - audio=audio, - sampling_rate=sample_rate, - filepath=f"{columns[column_offset]}.{audio_format}", - metadata_column_values=metadata_column_values, - ) - validate( - output_manifest, - output_zip, - self.config.audio_format, - self.config.output_validation_token, - ) + try: + line, audio = audio_load_utils.load_audio( + column_offset=column_offset, + gpu=False, + fp16=False, + line=line, + sampling_factor=int(sample_rate / 1000), + ) + columns = line.rstrip("\n").split("\t") + audio = 
torch.tensor(audio, dtype=torch.float) + metadata_column_values = [] + if self.config.store_num_frames: + metadata_column_values.append(str(audio.size(-1))) + if self.config.store_input_line: + metadata_column_values.append(line.rstrip("\n")) + writer.append( + audio=audio, + sampling_rate=sample_rate, + filepath=f"{columns[column_offset]}.{audio_format}", + metadata_column_values=metadata_column_values, + ) + except Exception: # type: ignore + self.logger.warning(f"Error in line {line}") + err.write(line) + + if not validate( + output_manifest, + output_zip, + self.config.audio_format, + self.config.output_validation_token, + ): + self.logger.warning(f"Shard {iteration_index} is not complete") return output_zip, output_manifest def name(self): @@ -203,11 +331,17 @@ def validate( audio_zip: Path, audio_format: str, output_validation_token: tp.Optional[bool] = False, -) -> None: - with zipfile.ZipFile(audio_zip) as z: - num_audio_files = len( - [i for i in z.infolist() if i.filename.endswith(f".{audio_format}")] - ) +) -> bool: + if not Path(audio_zip).exists(): + return False + + try: + with zipfile.ZipFile(audio_zip) as z: + num_audio_files = len( + [i for i in z.infolist() if i.filename.endswith(f".{audio_format}")] + ) + except IOError as exc: + return False num_lines = 0 for line in utils.open(manifest): @@ -218,12 +352,13 @@ def validate( audio.load() num_lines += 1 - assert ( - num_lines == num_audio_files - ), f"Found {num_lines} lines in {manifest}, but {num_audio_files} audio files in {audio_zip}" - if output_validation_token: - # persist validation token - open(f"{manifest}.validated-ok", "w").close() + if num_lines != num_audio_files and num_lines > 0: + return False + else: + if output_validation_token: + # persist validation token + open(f"{manifest}.validated-ok", "w").close() + return True class PostAudioZipModule(AudioZipModule): @@ -250,16 +385,9 @@ def __init__( ): super().__init__(config) - # post audiozip works with columns as a list - if type(self.config.column) == str: - self.columns = self.config.column.strip().strip("[]").split(",") - elif type(self.config.column) == int: - self.columns = [self.config.column] # type: ignore[list-item] - else: - raise ValueError("AudioZip config only accepts column of type int or str") - + self.columns = parse_hydra_list(self.config.column) if len(intermediate_zips) <= 1: - log.warning( + self.logger.warning( "This module is called within a pipeline with one shard." "Normally this means the mining file is not big enough." "Your pipeline might be simpler without the PostAudioZip module." @@ -270,13 +398,13 @@ def __init__( def array(self): # Make PostAudioZip a virtual array module to access function 'resolve_column_index'. 
# TODO: Put `resolve_column_index()` to a general utils module - header = isinstance(self.columns[0], str) and not self.columns[0].isdecimal() - return make_one_shard(self.config.tsv_file, header, sep="\t") + header = isinstance(self.columns[0], str) and not self.columns[0].isdecimal() # type: ignore + return make_one_text_shard(self.config.input_file, header, sep="\t") def run(self, iteration_value: tp.Any = None, iteration_index: int = 0): input_file = iteration_value - assert isinstance(input_file, Shard) - column_offsets = [input_file.resolve_column_index(col) for col in self.columns] + assert isinstance(input_file, TextShard) + column_offsets = [input_file.resolve_column_index(col) for col in self.columns] # type: ignore # combine all zips into a zip files saved in self.output_zip_{0000x}) compacted_zip_files = self.combine_zips() @@ -496,12 +624,14 @@ def append( sampling_rate: Sampling rate. filepath: The original audio file path if applicable, when it is None, a random name would be used. metadata_column_values: A list of metadata values you want to append for the manifest file, - the size of `metadata_column_values` should be equal to the size of `self.metadata_column_names`. when add_header is True + the size of `metadata_column_values` should be equal to the size of `self.metadata_column_names` + when add_header is True """ if self.add_header and metadata_column_values is not None: assert len(metadata_column_values) == len(self.metadata_column_names), ( f"The number of metadata values should be equal to the number of metadata columns, " - f"got number of metadata values: {len(metadata_column_values)}, number of metadata columns: {len(self.metadata_column_names)}" + f"got number of metadata values: {len(metadata_column_values)}, " + f"number of metadata columns: {len(self.metadata_column_names)}" ) if len(audio.size()) == 1: # torchaudio.save expects to have a 2D tensor @@ -517,6 +647,6 @@ def append( # Simple command line to run the audio zipping without post-processing logging.basicConfig(level=logging.INFO) - # python -m stopes.modules.speech.audio_zip --tsv_file mining_results.tsv --column 1 --output_dir /myoutput + # python -m stopes.modules.speech.audio_zip --input_file mining_results.tsv --column 1 --output_dir /myoutput cfg = func_argparse.single_main(AudioZipConfig) AudioZipModule(cfg).run() diff --git a/stopes/modules/speech/speech_units.py b/stopes/modules/speech/speech_units.py index 7d57612..f26682b 100644 --- a/stopes/modules/speech/speech_units.py +++ b/stopes/modules/speech/speech_units.py @@ -19,9 +19,13 @@ from stopes import hub from stopes.core import Requirements, StopesModule, utils -from stopes.modules.speech.audio_load_utils import load_audio, parallel_audio_read +from stopes.modules.speech.audio_load_utils import parallel_audio_read from stopes.speech.tokenizers import SpeechTokenizer -from stopes.utils.shards import Shard, make_shards, resolve_output +from stopes.utils.sharding.text_shards import ( + TextShard, + make_text_file_shards, + resolve_output, +) log = logging.getLogger("stopes.speech.units") multiprocessing.set_start_method("spawn", force=True) @@ -101,7 +105,7 @@ def __init__(self, config: SpeechUnitsConfig): super().__init__(config) self.output_dir = Path(self.config.output_dir) self.output_dir.mkdir(exist_ok=True) - self._current_progress: Optional[Shard] = None + self._current_progress: Optional[TextShard] = None @functools.cached_property def tokenizer(self) -> SpeechTokenizer: @@ -119,12 +123,12 @@ def requirements(self) -> Requirements: 
constraint=None if self.tokenizer.config["fp16"] else "volta32gb", ) - def array(self) -> List[Shard]: + def array(self) -> List[TextShard]: header = ( isinstance(self.config.column, str) and not self.config.column.isdecimal() ) return list( - make_shards( + make_text_file_shards( self.config.shards, cache_dir=self.output_dir, header=header, @@ -144,7 +148,7 @@ def run( iteration_value: Optional[Any] = None, iteration_index: Optional[int] = None, ) -> Any: - assert isinstance(iteration_value, Shard) + assert isinstance(iteration_value, TextShard) shard = iteration_value self._current_progress = shard lang = getattr(self.tokenizer.config, "lang", None) @@ -198,7 +202,7 @@ def run( def checkpoint( self, - iteration_value: Shard, + iteration_value: TextShard, iteration_index: int, **kwargs: Any, ) -> submitit.helpers.DelayedSubmission: diff --git a/stopes/modules/speech/wav2vec/asr.py b/stopes/modules/speech/wav2vec/asr.py index 17088ad..c5ef30d 100644 --- a/stopes/modules/speech/wav2vec/asr.py +++ b/stopes/modules/speech/wav2vec/asr.py @@ -6,9 +6,7 @@ # # Different modules for speech recognition / speech-to-text import dataclasses -import importlib import logging -import sys import typing as tp from pathlib import Path @@ -26,7 +24,12 @@ from stopes.speech.asr.wav2vec import decoder as wave2vec_decoder from stopes.speech.asr.wav2vec.base_decoder import BaseDecoder from stopes.speech.asr.wav2vec.decoder_config import FlashlightDecoderConfig -from stopes.utils.shards import Shard, make_shards, parse_header, resolve_output +from stopes.utils.sharding.text_shards import ( + TextShard, + make_text_file_shards, + parse_header, + resolve_output, +) @dataclasses.dataclass @@ -78,7 +81,7 @@ def __init__(self, config: ASRConfig, **kwargs): ) self.logger = logging.getLogger("stopes.asr") self.kwargs = kwargs - self._current_progress: tp.Optional[Shard] = None + self._current_progress: tp.Optional[TextShard] = None def load_model_and_task(self) -> tp.Tuple[tp.List[FairseqModel], FairseqTask]: """Load a Wav2vec Encoder and the ASR task""" @@ -104,9 +107,9 @@ def requirements(self) -> Requirements: cpus_per_task=int(self.config.cpus_per_task), ) - def array(self) -> tp.List[Shard]: + def array(self) -> tp.List[TextShard]: return list( - make_shards( + make_text_file_shards( self.config.shards, nshards=self.config.nshards, header=self.header, @@ -193,8 +196,10 @@ def run( assert Path( self.config.shards ).is_file(), "Direct call of run() only works with a single shard" - cols = parse_header(self.config.shards, "\t") if self.header else None - shard = Shard(self.config.shards, cols=cols, sep="\t") + cols = parse_header(self.config.shards, self.header, "\t") + shard = TextShard( + input_file=self.config.shards, columns=cols, sep="\t", filter=None + ) self._current_progress = shard # Set up I/O variables @@ -237,7 +242,7 @@ def run( def checkpoint( self, - iteration_value: Shard, + iteration_value: TextShard, iteration_index: int, **kwargs: tp.Any, ) -> submitit.helpers.DelayedSubmission: diff --git a/stopes/modules/speech/whisper.py b/stopes/modules/speech/whisper.py index 7c74784..002ae39 100644 --- a/stopes/modules/speech/whisper.py +++ b/stopes/modules/speech/whisper.py @@ -17,8 +17,14 @@ import stopes.modules.speech.utils as speech_utils from stopes.core.stopes_module import Requirements, StopesModule +from stopes.modules.speech.audio_load_utils import load_audio from stopes.modules.speech.speech_units import parallel_audio_read -from stopes.utils.shards import Shard, make_shards, parse_header, 
resolve_output +from stopes.utils.sharding.text_shards import ( + TextShard, + make_text_file_shards, + parse_header, + resolve_output, +) logger = logging.getLogger("stopes.whisper") @@ -45,6 +51,7 @@ class WhisperConfig: nshards: int = 1 gpu: bool = True cpus_per_task: int = 4 + timeout_min: int = 180 class WhisperModule(StopesModule): @@ -59,11 +66,11 @@ def __init__(self, config: WhisperConfig): self.header = ( isinstance(self.config.column, str) and not self.config.column.isdecimal() ) - self._current_progress: tp.Optional[Shard] = None + self._current_progress: tp.Optional[TextShard] = None - def array(self) -> tp.List[Shard]: + def array(self) -> tp.List[TextShard]: return list( - make_shards( + make_text_file_shards( self.config.shards, nshards=self.config.nshards, header=self.header, @@ -77,7 +84,7 @@ def requirements(self) -> Requirements: return Requirements( gpus_per_node=int(self.config.gpu), cpus_per_task=int(self.config.cpus_per_task), - timeout_min=180, + timeout_min=int(self.config.timeout_min), ) def get_audio(self, infile: str, ts_start: int, ts_end: int) -> torch.Tensor: @@ -104,8 +111,21 @@ def get_lines( line = line.rstrip() line = line.split("\t")[column_offset] infile, ts_start, ts_end, _ = speech_utils.parse_audio_deprecated(line) + # If no ts_start and ts_end provided, + # return ts_start = 0 and ts_end = len(wav) to read the full audio instead + if not (ts_start and ts_end) and not self.config.longest_segment: + ts_start = 0 + line, wav = load_audio( + column_offset, + gpu=False, + fp16=True, + line=line, + sampling_factor=16, + collapse_channels=True, + ) + ts_end = len(wav) assert ( - ts_start and ts_end + ts_start is not None and ts_end is not None ), f"Cannot parse timetamp from audio segment info: {line}" ts_start = int(ts_start) ts_end = int(ts_end) @@ -166,7 +186,9 @@ def run( self.config.shards ).is_file(), "Direct call of run() only works with a single shard" cols = parse_header(self.config.shards, "\t") if self.header else None - shard = Shard(self.config.shards, cols=cols, sep="\t") + shard = TextShard( + input_file=self.config.shards, columns=cols, sep="\t", filter=None + ) self._current_progress = shard assert ( @@ -178,9 +200,11 @@ def run( out_file ), f"Cannot determine the output file name for {shard.input_file} (shard #{shard.index})" column_offset = shard.resolve_column_index(self.config.column) - with shard as f, open( - out_file, "a+" - ) as o, tempfile.TemporaryDirectory() as data_gym_cache: + with ( + shard as f, + open(out_file, "a+") as o, + tempfile.TemporaryDirectory() as data_gym_cache, + ): os.environ["DATA_GYM_CACHE_DIR"] = str(data_gym_cache) self.model = whisper.load_model(self.config.model) diff --git a/stopes/modules/tests/test_partitioned_data_mapper.py b/stopes/modules/tests/test_partitioned_data_mapper.py new file mode 100644 index 0000000..d2a5e6e --- /dev/null +++ b/stopes/modules/tests/test_partitioned_data_mapper.py @@ -0,0 +1,348 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import math +import os +import shutil +import tempfile +import unittest +from pathlib import Path +from typing import Any, Dict + +import numpy as np +import pandas as pd +import pyarrow as pa +import pyarrow.dataset as ds +import pyarrow.parquet as pq + +from stopes.modules.partitioned_data_mapper import ( + PartitionedDataMapper, + PartitionedDataMapperConfig, +) +from stopes.utils.sharding.parquet_shards import ( + ParquetOutputConfig, + ParquetShardingConfig, +) +from stopes.utils.sharding.text_shards import TextOutputConfig, TextShardingConfig +from stopes.utils.tests.conftest import permutationally_equal_dataframes + +NUM_ROW_GROUPS = 15 +FULL_DS_SIZE = 3 * 10**5 + + +class IdentityPartitionedDataMapper(PartitionedDataMapper): + def get_batch_mapper(self): + return lambda batch: batch + + def requirements(self): + ... + + def get_custom_metadata(self, *args, **kwargs) -> Dict[str, Any]: + return {} + + +def generated_partitioned_parquet_dataset( + path: str, size: int, n_partitions: int = 5, seed: int = 123 +) -> None: + np_rs = np.random.RandomState(seed) + df = { + "int_col": np_rs.randint(0, 200, size), + "float_col": np.round(np_rs.randn(size), 10), + "bool_col": np_rs.randn(size) > 0, + "part_key": np.arange(size) % (n_partitions + 1), + } + + table = pa.Table.from_pydict(df) + + pq.write_to_dataset( + table, + path, + partition_cols=["part_key"] if n_partitions > 0 else None, + **{"max_rows_per_group": 2000, "min_rows_per_group": 1000}, + ) + df_pd = table.to_pandas() + df_pd.to_csv(os.path.join(path, ".tsv"), sep="\t", index=False) + + +class TestPartitionedDataMapper(unittest.TestCase): + _tmpdir: str + _tmp_parquet_ds_path: str + _tmp_parquet_single_path: str + + @classmethod + def setUpClass(cls) -> None: + cls._tmpdir = tempfile.mkdtemp() + cls._tmp_parquet_ds_path = os.path.join(cls._tmpdir, "test") + generated_partitioned_parquet_dataset( + cls._tmp_parquet_ds_path, size=FULL_DS_SIZE + ) + + cls._tmp_parquet_single_path = os.path.join(cls._tmpdir, "single_test.parquet") + generated_partitioned_parquet_dataset( + cls._tmp_parquet_single_path, size=FULL_DS_SIZE, n_partitions=0 + ) + + @classmethod + def tearDownClass(cls) -> None: + # Cleanup temp working dir. 
+ if cls._tmpdir is not None: + shutil.rmtree(cls._tmpdir) # type: ignore + + def test_basic_parquet_to_parquet_walkthrough(self): + input_config = ParquetShardingConfig( + input_file=self._tmp_parquet_ds_path, + batch_size=100, + split_row_groups=True, + fragment_group_size=2, + filters_expr="pa.compute.and_kleene(pa.compute.greater(ds.field('part_key'), 2), pa.compute.greater(ds.field('float_col'), 0))", + columns=["part_key", "float_col", "bool_col"], + ) + output_config = ParquetOutputConfig( + os.path.join(self._tmpdir, "output"), + batch_size=10**3, + max_rows_per_file=600, + ) + mapper_config = PartitionedDataMapperConfig(input_config, output_config) + + dpdm = IdentityPartitionedDataMapper(mapper_config) + self.assertEqual(len(dpdm.array()), 15) # so this's clearly split by row groups + + output_paths = [ + y for i, frag in enumerate(dpdm.array()) for y in dpdm.run(frag, i) # type: ignore + ] + reloaded_dataset = pa.parquet.ParquetDataset(output_paths) + + reloaded_table = reloaded_dataset.read() + input_table = pa.parquet.ParquetDataset( + self._tmp_parquet_ds_path, + filters=pa.compute.and_kleene( + pa.compute.greater(ds.field("part_key"), 2), + pa.compute.greater(ds.field("float_col"), 0), + ), + ).read(columns=["part_key", "float_col", "bool_col"]) + + assert permutationally_equal_dataframes( + reloaded_table.to_pandas(), input_table.to_pandas() + ) + # check metadata + metadata = reloaded_table.schema.metadata + assert list(metadata.keys()) == [ + b"config", + b"batch_mapper_class", + b"batch_mapper_code", + b"username", + b"save_time", + b"previous_metadata", + ] + assert metadata[b"batch_mapper_class"] == b'{"": ""}' + assert ( + metadata[b"batch_mapper_code"] == b'" return lambda batch: batch\\n"' + ) + + assert len(output_paths) == 153 + + def test_basic_tsv_to_parquet_walkthrough(self): + input_config = TextShardingConfig( + input_file=os.path.join(self._tmp_parquet_ds_path, ".tsv"), + columns=["part_key", "float_col"], + filters_expr="pa.compute.greater(ds.field('float_col'), 0)", + sep="\t", + batch_size=211, + nb_shards=11, + header=True, + ) + + output_config = ParquetOutputConfig( + os.path.join(self._tmpdir, "output_from_tsv"), + keep_same_partitioning=False, + partition_columns=["part_key"], + ) + mapper_config = PartitionedDataMapperConfig(input_config, output_config) + + dpdm = IdentityPartitionedDataMapper(mapper_config) + self.assertEqual(len(dpdm.array()), 11) + + output_paths = [ + y for i, frag in enumerate(dpdm.array()) for y in dpdm.run(frag, i) # type: ignore + ] + self.assertEqual(len(output_paths), 11 * (5 + 1)) + + reloaded_table = ( + pa.parquet.ParquetDataset(output_paths).read_pandas().to_pandas() + ) + + input_table = ( + pa.parquet.ParquetDataset( + self._tmp_parquet_ds_path, + filters=pa.compute.greater(ds.field("float_col"), 0), + ) + .read_pandas(columns=["part_key", "float_col"]) + .to_pandas() + ) + + assert permutationally_equal_dataframes(reloaded_table, input_table) + + def test_basic_tsv_to_tsv_walkthrough(self): + input_config = TextShardingConfig( + input_file=os.path.join(self._tmp_parquet_ds_path, ".tsv"), + columns=["part_key", "float_col"], + filters_expr="pa.compute.greater(ds.field('float_col'), -2.)", + sep="\t", + batch_size=1112, + nb_shards=110, + header=True, + ) + output_config = TextOutputConfig( + os.path.join(self._tmpdir, "output_text"), + sep="\t", + ) + + mapper_config = PartitionedDataMapperConfig(input_config, output_config) + + dpdm = IdentityPartitionedDataMapper(mapper_config) + self.assertEqual(len(dpdm.array()), 
110) + + output_paths = [ + y for i, frag in enumerate(dpdm.array()) for y in dpdm.run(frag, i) # type: ignore + ] + self.assertEqual(len(output_paths), 110) + self.assertTrue(all(str(path).endswith(".tsv") for path in output_paths)) + + reloaded_table = pd.concat( + [pd.read_csv(path, sep="\t", compression=None) for path in output_paths], + axis=0, + ) + + input_table = pd.read_csv( + os.path.join(self._tmp_parquet_ds_path, ".tsv"), + sep="\t", + usecols=["part_key", "float_col"], + ) + input_table = input_table[input_table["float_col"] > -2.0] + + assert permutationally_equal_dataframes(reloaded_table, input_table) + + def test_basic_parquet_to_tsv_walkthrough(self): + + input_config = ParquetShardingConfig( + input_file=self._tmp_parquet_ds_path, + split_row_groups=False, + columns=["part_key", "float_col", "bool_col"], + ) + out_dir = os.path.join(self._tmpdir, "output_text") + output_config = TextOutputConfig(out_dir, sep="\t", compression="gzip") + mapper_config = PartitionedDataMapperConfig(input_config, output_config) + + dpdm = IdentityPartitionedDataMapper(mapper_config) + shards = dpdm.array() + self.assertEqual(len(shards), 6) # so this's clearly split by row groups + + output_paths = [ + y for i, frag in enumerate(shards) for y in dpdm.run(frag, i) # type: ignore + ] + self.assertEqual(len(output_paths), 6) + self.assertTrue(all(str(path).endswith(".tsv.gzip") for path in output_paths)) + + reloaded_table = pd.concat( + [pd.read_csv(path, sep="\t", compression="gzip") for path in output_paths], + axis=0, + ) + + input_table = ( + pa.parquet.ParquetDataset( + self._tmp_parquet_ds_path, + ) + .read_pandas(columns=["part_key", "float_col", "bool_col"]) + .to_pandas() + ) + + assert len(list(Path(out_dir).glob(".text_output.*.state"))) == len(shards) + final_states = [output_config.reload_state(shard) for shard in shards] + assert all([s is not None for s in final_states]) + total_states_row = sum([s.input_rows_written for s in final_states]) # type: ignore + assert total_states_row == len(input_table) + + assert permutationally_equal_dataframes(reloaded_table, input_table) + + def test_limits_sharding_number(self): + + input_config_parq = ParquetShardingConfig( + input_file=self._tmp_parquet_ds_path, + split_row_groups=True, + take=3, + columns=["part_key", "float_col", "bool_col"], + ) + + self.assertEqual( + len(input_config_parq.make_shards()), 3 + ) # so this's clearly split by row groups + + input_config_text = TextShardingConfig( + input_file=os.path.join(self._tmp_parquet_ds_path, ".tsv"), + take=2, + filters_expr="pa.compute.greater(ds.field('float_col'), -2.)", + sep="\t", + batch_size=112, + nb_shards=110, + header=True, + ) + + self.assertEqual( + len(input_config_text.make_shards()), 2 + ) # so this's clearly split by row groups + + def test_parquet_skipping(self): + input_table = pa.parquet.ParquetDataset( + self._tmp_parquet_ds_path, + filters=pa.compute.and_kleene( + pa.compute.greater(ds.field("part_key"), 2), + pa.compute.greater(ds.field("float_col"), 0), + ), + ).read(columns=["part_key", "float_col", "bool_col"]) + + original_len = len(input_table) + to_skip = math.floor( + (original_len / NUM_ROW_GROUPS) / 3 + ) # 1/3 of a shard, we have 15 shards + + input_config = ParquetShardingConfig( + input_file=self._tmp_parquet_ds_path, + batch_size=100, + split_row_groups=True, + fragment_group_size=2, + filters_expr="pa.compute.and_kleene(pa.compute.greater(ds.field('part_key'), 2), pa.compute.greater(ds.field('float_col'), 0))", + columns=["part_key", "float_col", 
"bool_col"], + skip_n_rows_per_shard={0: to_skip}, + ) + out_dir = os.path.join(self._tmpdir, "output_parquet") + output_config = ParquetOutputConfig( + out_dir, + batch_size=10**3, + max_rows_per_file=600, + ) + mapper_config = PartitionedDataMapperConfig(input_config, output_config) + + dpdm = IdentityPartitionedDataMapper(mapper_config) + + shards = dpdm.array() + output_paths = [ + y for i, frag in enumerate(shards) for y in dpdm.run(frag, i) # type: ignore + ] + + reloaded_dataset = pa.parquet.ParquetDataset(output_paths) + + reloaded_table = reloaded_dataset.read() + + expected = original_len - to_skip + actual = len(reloaded_table) + assert expected == actual + + assert len(list(Path(out_dir).glob(".parquet_output.*.state"))) == len(shards) + final_states = [output_config.reload_state(shard) for shard in shards] + assert all([s is not None for s in final_states]) + total_states_row = sum([s.input_rows_written for s in final_states]) # type: ignore + assert total_states_row == len(input_table) diff --git a/stopes/modules/tests/test_speech_utils.py b/stopes/modules/tests/test_speech_utils.py index afcdf0d..c120ada 100644 --- a/stopes/modules/tests/test_speech_utils.py +++ b/stopes/modules/tests/test_speech_utils.py @@ -20,6 +20,7 @@ import stopes.modules.speech.postprocess as sprocess import stopes.modules.speech.utils as sputils +from stopes.modules.speech.audio_load_utils import load_audio from stopes.modules.speech.utils import Audio, AudioBytes, Text @@ -358,27 +359,11 @@ def test_parse_audio_bytes_with_resample(sample_audio, caplog): @pytest.mark.parametrize("gpu", [True, False]) @pytest.mark.parametrize("fp16", [True, False]) -@pytest.mark.parametrize("custom_read_func", [True, False]) -def test_load_audio_in_devices(sample_audio, gpu, fp16, custom_read_func, caplog): - # TODO: move load_audio to speech_utils - from stopes.modules.speech.speech_units import load_audio - from stopes.modules.speech.utils import read_audio - +def test_load_audio_in_devices(sample_audio, gpu, fp16, caplog): audio_path, num_frames = sample_audio[0] - fake_read_func = lambda *a: torch.zeros(1) # noqa - read_func = fake_read_func if custom_read_func else read_audio # ignore[assignment] line = f"{audio_path}|0|{num_frames}|16\tcoluimn2" - load_res = load_audio( - 0, gpu, fp16, line, sampling_factor=32, read_audio_func=read_func # type: ignore[arg-type] - ) + load_res = load_audio(0, gpu, fp16, line, sampling_factor=32) # type: ignore[arg-type] assert load_res[0] == line - if custom_read_func: - expected_wav = torch.zeros(1) - if gpu and torch.cuda.is_available(): - expected_wav = expected_wav.cuda() - if fp16: - expected_wav = expected_wav.half() - assert torch.equal(load_res[1], expected_wav) # type: ignore[arg-type] line_no_sample = f"{audio_path}|0|{num_frames}\tcoluimn2" with assert_warns( caplog, match="Sampling factor not present in file, using provided value." diff --git a/stopes/modules/tests/test_uromanize_cli.py b/stopes/modules/tests/test_uromanize_cli.py deleted file mode 100644 index 45b30fd..0000000 --- a/stopes/modules/tests/test_uromanize_cli.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
-import os -from pathlib import Path - -import pytest - -import stopes -from stopes.core import utils -from stopes.modules.preprocess.uromanize_cli_module import ( - run_uroman_cli_standalone, - uromanize, -) -from stopes.pipelines.tests import test_configs - - -def _mock_input_file(tmp_path: Path) -> Path: - mock_input_file = tmp_path / "uroman_input" - with utils.open(mock_input_file, "w") as in_f: - in_f.write("ちょっとまってください\n" "アメリカ") - return mock_input_file - - -def test_run_uroman_cli_standalone(tmp_path: Path) -> None: - input_file = _mock_input_file(tmp_path) - output_file = tmp_path / "uroman_output" - run_uroman_cli_standalone(input_file, output_file, lang="xxx") - expected_output = "chottomattekudasai\namerika\n" - with utils.open(output_file) as out_f: - assert out_f.read() == expected_output - - -def test_uroman_cli_module(tmp_path: Path) -> None: - input_file = _mock_input_file(tmp_path) - output_dir = tmp_path / "output" - conf_path = ( - test_configs.STOPES / "pipelines" / "speech" / "conf" / "uromanization.yaml" - ) - cfg = test_configs.load_conf( - conf_path, - ( - f"lang=xxx", - f"input_file={input_file}", - f"output_dir={output_dir}", - ), - ) - module = stopes.core.StopesModule.build(cfg) - module.requirements() - output_file = module.run() - expected_output = "chottomattekudasai\namerika\n" - with utils.open(output_file) as out_f: - assert out_f.read() == expected_output - assert os.path.basename(out_f.name) == "uroman.uroman_input.xxx" - - -@pytest.mark.parametrize( - "text,expected_output", - [ - (["ちょっとまってください", "アメリカ"], ["chottomattekudasai", "amerika"]), - ([], []), - (None, []), - ], -) -def test_uromanize(text, expected_output): - assert uromanize(text) == expected_output diff --git a/stopes/pipelines/__init__.py b/stopes/pipelines/__init__.py index 0952fcc..1789616 100644 --- a/stopes/pipelines/__init__.py +++ b/stopes/pipelines/__init__.py @@ -3,3 +3,34 @@ # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. + +from collections import UserDict + +from hydra.core.plugins import Plugins +from hydra.plugins.search_path_plugin import SearchPathPlugin + + +class StopesConfigRegistry(UserDict): + def __setitem__(self, key: str, item: str) -> None: + super().__setitem__(key, item) + + # Register key as config provider and value as config path + stopes_config_plugin = type( + "StopesConfigPath", + (SearchPathPlugin, object), + { + "manipulate_search_path": lambda self, search_path: search_path.append( + provider=key, path=item + ) + }, + ) + Plugins.instance().register(stopes_config_plugin) + + +config_registry = StopesConfigRegistry() + + +# Register the common pipeline configs in stopes +config_registry["stopes-common"] = "pkg://stopes/pipelines/conf" +config_registry["stopes-text-mining"] = "pkg://stopes/pipelines/bitext/conf" +config_registry["stopes-speech-mining"] = "pkg://stopes/pipelines/speech/conf" diff --git a/stopes/pipelines/bitext/conf/audio_zip/base.yaml b/stopes/pipelines/bitext/conf/audio_zip/base.yaml index 11293ca..d45c828 100644 --- a/stopes/pipelines/bitext/conf/audio_zip/base.yaml +++ b/stopes/pipelines/bitext/conf/audio_zip/base.yaml @@ -1,5 +1,5 @@ _target_: stopes.modules.speech.audio_zip.AudioZipModule -tsv_file: ??? +input_file: ??? output_dir: ??? output_prefix: "" sample_rate: 16000 @@ -14,6 +14,12 @@ output_validation_token: false column: ??? nshards: ??? 
+# If true, the audio paths will be sorted to speed up the audio reading (by taking advantage of +# memory cache when loading segments from the same file). This means the order of audios in the +# zip file is not the same as in the input TSV file. If False, no sorting is performed, the audios +# are appended to the zip file in the same order as in the input TSV file +sorted: true + # Whether to add duplicate segments to the zip file. Default true but this can be turned off # for debugging purpose -no_duplicate: true +no_duplicate: true \ No newline at end of file diff --git a/stopes/utils/arrow_utils.py b/stopes/utils/arrow_utils.py new file mode 100644 index 0000000..137c08d --- /dev/null +++ b/stopes/utils/arrow_utils.py @@ -0,0 +1,561 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import json +import typing as tp + +import numpy as np +import pyarrow as pa +import pyarrow.compute as pc + + +class DataClassEncoder(json.JSONEncoder): + def default(self, obj): + import dataclasses + + if dataclasses.is_dataclass(obj): + return dataclasses.asdict(obj) + import omegaconf + + if isinstance(obj, omegaconf.dictconfig.DictConfig): + resolved = omegaconf.OmegaConf.to_container(obj, resolve=True) + return resolved + import enum + + if isinstance(obj, enum.Enum): + return obj.name # or return obj.name if you want the name of the enum + return super().default(obj) + + +def hash_table_with_schema(table, seed: int = 0) -> str: + """ + Computes a hash for a pyarrow.Table including its schema using xxHash. + This function serializes the schema of the table and updates the hash with it, + ensuring that any changes in the schema affect the hash result. It then iterates + over each column and chunk in the table, updating the hash with the data buffers. + This approach provides a comprehensive hash that reflects both the structure + and content of the table. + + Parameters: + - table (pyarrow.Table): The PyArrow table to hash. + - seed (int, optional): An optional seed for the xxHash function. Default is 0. + Returns: + - str: The hexadecimal string representing the hash of the table including its schema. 
+ Example: + >>> data = { + 'column1': [1, 2, 3, 4], + 'column2': ['foo', 'bar', 'baz', 'qux'] + } + >>> table = pa.Table.from_pydict(data) + >>> hash_table_with_schema(table) + 394e32679db7eced + + """ + import xxhash + + hash_obj = xxhash.xxh64(seed=seed) + # Serialize the schema to a string and update the hash + schema_str = table.schema.serialize().to_pybytes() + hash_obj.update(schema_str) + + for column in table.itercolumns(): + for buffer in pyarrow_column_to_array(column).buffers(): + if buffer is not None: + hash_obj.update(buffer) + + return hash_obj.hexdigest() + + +def add_metadata_to_table(table: pa.Table, meta: dict) -> pa.Table: + existing_metadata = table.schema.metadata or {} + encoded_meta = { + key: json.dumps(val, cls=DataClassEncoder) for key, val in meta.items() + } + if existing_metadata: + encoded_meta["previous_metadata"] = json.dumps( + { + key.decode("utf-8"): val.decode("utf-8") + for key, val in existing_metadata.items() + }, + cls=DataClassEncoder, + ) + + return table.replace_schema_metadata(encoded_meta) + + +def is_list_like(arr): + return pa.types.is_list(arr.type) or pa.types.is_large_list(arr.type) + + +def _fix_list_offset(arr: pa.Array) -> pa.Array: + """ + Recursively fixes list offset to 0, so that arr.offsets are always starts from 0 + and can be used easily downstream. + """ + if not is_list_like(arr): + return arr + if arr.offset == 0: + return arr + + new_values = _fix_list_offset(pc.list_flatten(arr)) + new_offsets = pc.subtract(arr.offsets, arr.offsets[0]) + + return ( + pa.LargeListArray.from_arrays(new_offsets, new_values) + if pa.types.is_large_list(arr.type) + else pa.ListArray.from_arrays(new_offsets, new_values) + ) + + +def pyarrow_column_to_array(arg: tp.Union[pa.ChunkedArray, pa.Array]) -> pa.Array: + # see https://github.com/apache/arrow/issues/37318 + if isinstance(arg, pa.Array): + return _fix_list_offset(arg) + + return _fix_list_offset( + arg.chunk(0) if arg.num_chunks == 1 else arg.combine_chunks() + ) + + +def numpy_to_fixed_size_pyarrow_array(array: np.ndarray) -> pa.Array: + assert array.ndim == 2 + buffer = array.ravel(order="C") + return pa.FixedSizeListArray.from_arrays(pa.array(buffer), array.shape[1]) + + +def apply_on_nested_array( + fn: tp.Callable[[pa.Array], pa.Array], arr: tp.Union[pa.ChunkedArray, pa.Array] +) -> tp.Union[pa.ChunkedArray, pa.Array]: + if is_list_like(arr): + arr = pyarrow_column_to_array(arr) + res = apply_on_nested_array(fn, pc.list_flatten(arr)) + + assert arr.offset == 0 + cls = pa.LargeListArray if pa.types.is_large_list(arr.type) else pa.ListArray + output = cls.from_arrays(arr.offsets, res) + if arr.null_count > 0: + output = pc.if_else(pc.is_null(arr), None, output) + return output + + return fn(arr) + + +def pyarrow_fixed_size_array_to_numpy( + cc: tp.Union[pa.ChunkedArray, pa.Array], +) -> np.ndarray: + cc = pyarrow_column_to_array(cc) + assert cc.null_count == 0 + assert cc.type.list_size is not None + return np.reshape(np.asarray(pc.list_flatten(cc)), (-1, cc.type.list_size)) + + +def nested_pyarrow_to_torch(arr: pa.Array): + """ + Transforms are List[List[ListOfFixedSize]] to Nested Torch Tensors of shape : + - batch_size x SeqLen* x Dim + The Tensor representation of Seq of Vectors batch. + + One can use + >>> normal_torch_tensor = nested_pyarrow_to_torch(arr).to_padded_tensor(0.) 
+ + Args: + arr (pa.Array): + + Returns: + torch.Tensor: + """ + import torch + + return torch.nested.as_nested_tensor( + arr.to_pandas().map(np.vstack).map(torch.from_numpy).tolist() + ) + + +def explode_table_include_null( + table: pa.Table, columns: tp.Union[str, tp.Sequence[str]] +) -> pa.Table: + """ + Similar to pandas.DataFrame.explode method for pyarrow.Table + >>> table = pa.table({'a': range(3), 'b': [[1, 2], None, [3, 4, 5]]}) + >>> explode_table_include_null(table, 'b').to_pandas() + a b + 0 0 1 + 1 0 2 + 2 2 3 + 3 2 4 + 4 2 5 + + + Args: + table (pa.Table): + columns (str): list type columns in table + + Returns: + pa.Table + """ + if isinstance(columns, str): + columns = [columns] + + assert len(columns) > 0 + + other_columns = list(table.schema.names) + for column in columns: + other_columns.remove(column) + + # checking compatibility + new_cols = [] + lengths = pc.list_value_length(pc.fill_null(table[columns[0]], [None])).to_numpy() + + for name in columns: + col = pc.fill_null(table[name], [None]) + # checking that all columns list structures are parallel + assert (lengths == pc.list_value_length(col).to_numpy()).all() + new_cols.append(pc.list_flatten(col)) + + if len(other_columns) > 0: + indices = pc.list_parent_indices(pc.fill_null(table[columns[0]], [None])) + result = table.select(other_columns).take(indices) + + for name, new_col in zip(columns, new_cols): + result = result.append_column( + pa.field(name, table.schema.field(name).type.value_type), new_col + ) + else: + result = pa.Table.from_arrays(new_cols, columns) + + return result + + +# numba njit +def _get_indices_and_offsets(lengths, max_seq_len): + new_lengths, res = [], [] + for i, ll in enumerate(lengths): + nb_full, remaining = ll // max_seq_len, ll % max_seq_len + + if remaining != 0: + res.append(np.full((nb_full + 1), i, dtype=np.int32)) + new_lengths.append( + np.array([max_seq_len] * nb_full + [remaining], dtype=np.int32) + ) + else: + res.append(np.full(nb_full, i, dtype=np.int32)) + new_lengths.append(np.array([max_seq_len] * nb_full, dtype=np.int32)) + + return ( + np.concatenate(res), + np.concatenate([np.array([0], dtype=np.int32)] + new_lengths).cumsum(), + ) + + +def _cast_fs16_to_int16(table: pa.Table) -> pa.Table: + # polars does not work with fs16 data type, but works with int16 + def _view_as_fs16(col): + if pa.types.is_fixed_size_list(col.type) and pa.types.is_float16( + col.type.value_type + ): + return col.view(pa.list_(pa.int16(), col.type.list_size)) + elif pa.types.is_float16(col.type): + return col.view(pa.int16()) + else: + return col + + out = {} + for col in table.column_names: + out[col] = apply_on_nested_array(_view_as_fs16, table[col]) + + return pa.Table.from_pydict(out) + + +def _cast_back_int16_to_fs16(table: pa.Table, reference_table: pa.Table) -> pa.Table: + # for compatibility with polars we cast int16 back to fs16 + # large_list to simple list + for col in table.column_names: + if pa.types.is_large_list(table[col].type) and pa.types.is_list( + reference_table[col] + ): + table = table.drop(col).append_column( + col, table[col].cast(pa.list_(table[col].type.value_type)) + ) + if table[col].type != reference_table[col].type: + casted_columns = pyarrow_column_to_array(table[col]).view( + reference_table[col].type + ) + table = table.drop(col).append_column(col, casted_columns) + return table + + +def explode_table_with_fixed_length( + table: pa.Table, columns: tp.Union[str, tp.Sequence[str]], max_seq_len: int +) -> pa.Table: + """ + This function takes an Apache Arrow 
Table, explodes it based on the specified columns, + and then rechunks the exploded table based on a specified sequence length + + ## Parameters: + - `table` (`pa.Table`): The input Apache Arrow Table that needs to be exploded and rechunked. + - `columns` (`tp.Union[str, tp.Sequence[str]]`): The column or columns on which the table should be exploded. + - `max_seq_len` (`int`): The sequence length for rechunking the exploded table. This should be a positive integer. + ## Returns: + - `pa.Table`: The rechunked Table after exploding on the specified columns. + + ## Example: + + >>> table = pa.Table.from_pydict({"col1": [[1, 2], [3, 4, 5, 6, 7], [8, 10], [11]], + ... "col2": [[-1, -2], [-3, -4, -5, -6, -7], [-8, -10], [-11]], + ... "col3": ["a", "b", "c", "d"]}) + >>> exploded_table = explode_table_with_fixed_length(table, ["col1", "col2"], 3) + >>> exploded_table.to_pandas() + col3 col1 col2 __doc_segments __doc_lengths + 0 [a, a, b] [1, 2, 3] [-1, -2, -3] [0, 0, 1] [2, 1] + 1 [b, b, b] [4, 5, 6] [-4, -5, -6] [0, 0, 0] [3] + 2 [b, c, c] [7, 8, 10] [-7, -8, -10] [0, 1, 1] [1, 2] + 3 [d] [11] [-11] [0] [1] + """ + assert max_seq_len > 0 + table = table.append_column("__doc_index", pa.array(np.arange(len(table)))) + flatten_table = explode_table_include_null(table, columns) + offsets = np.arange(0, len(flatten_table), max_seq_len, dtype=np.int64) + offsets = pa.array(np.concatenate([offsets, [len(flatten_table)]])) + + if len(flatten_table) < 2 * 32 - 1: + cls = pa.ListArray.from_arrays + else: + cls = pa.LargeListArray.from_arrays + + out = {} + for col in flatten_table.column_names: + out[col] = cls(offsets, pyarrow_column_to_array(flatten_table[col])) + out_table = pa.Table.from_pydict(out) + + # transforming "__doc_index" to ordered segments indices + # [1, 1, 1, 2, 2, 4, 5, 5] -> [0, 0, 0, 1, 1, 2, 3, 3] + arr = out_table["__doc_index"] + rows = [ + pc.value_counts(pc.list_flatten(arr.slice(i, 1))).field(1).to_numpy() + for i in range(len(arr)) + ] + doc_lengths = pa.array(rows) + doc_segments = pa.array( + [np.repeat(np.arange(len(xx), dtype=np.int32), xx) for xx in rows] + ) + + return ( + out_table.drop(["__doc_index"]) + .append_column("__doc_segments", doc_segments) + .append_column("__doc_lengths", doc_lengths) + ) + + +def explode_table_with_max_length( + table: pa.Table, columns: tp.Union[str, tp.Sequence[str]], max_seq_len: int +) -> pa.Table: + """ + Unrolling list array into smaller list with fixed max length. + If provided several `columns`, all columns are supposed to the parallel list structure. 
+ + >>> table = pa.table({'a': range(5), 'b': [[2, 2], None, [3,3,3], [4,4,4,4], [5,5,5,5,5]]}) + >>> explode_table_with_max_length(table, "b", 2).to_pandas() + a b + 0 0 [2.0, 2.0] + 1 1 [nan] + 2 2 [3.0, 3.0] + 3 2 [3.0] + 4 3 [4.0, 4.0] + 5 3 [4.0, 4.0] + 6 4 [5.0, 5.0] + 7 4 [5.0, 5.0] + 8 4 [5.0] + >>> explode_table_with_max_length(table, "b", 3).to_pandas() + a b + 0 0 [2.0, 2.0] + 1 1 [nan] + 2 2 [3.0, 3.0, 3.0] + 3 3 [4.0, 4.0, 4.0] + 4 3 [4.0] + 5 4 [5.0, 5.0, 5.0] + 6 4 [5.0, 5.0] + """ + if isinstance(columns, str): + columns = [columns] + + assert len(columns) > 0 + + cols = [pc.fill_null(table[columns[0]], [None])] + lengths = pc.list_value_length(cols[0]).to_numpy() + + for name in columns[1:]: + col = pc.fill_null(table[name], [None]) + # checking that all columns list structures are parallel + assert (lengths == pc.list_value_length(col).to_numpy()).all() + cols.append(col) + + # next unroll with max_seq_len + indices, new_offests = _get_indices_and_offsets(lengths, max_seq_len) + + other_columns = list(table.schema.names) + for name in columns: + other_columns.remove(name) + + remaining_table = table.select(other_columns).take(indices) + + result_dict = {} + for name in other_columns: + result_dict[name] = remaining_table[name] + + for name, col in zip(columns, cols): + rolled_array = pa.ListArray.from_arrays( + offsets=new_offests, + values=pyarrow_column_to_array(pc.list_flatten(col)), + ) + result_dict[name] = rolled_array + + return pa.Table.from_pydict(result_dict, schema=table.schema) + + +def nested_numpy_to_pyarrow(series: tp.Union[list, np.ndarray]) -> pa.Array: + """ + >>> a = [np.random.rand(i + 1, 2) for i in range(10)] + >>> nested_array = nested_numpy_to_pyarrow(a) + >>> nested_array.type + ListType(list[2]>) + """ + offsets = np.array([0] + list(map(len, series)), dtype=np.int32).cumsum() + values = numpy_to_fixed_size_pyarrow_array(np.vstack(list(map(np.asarray, series)))) + return pa.ListArray.from_arrays(offsets, values) + + +def simple_array_to_nested(arr: tp.Union[pa.ChunkedArray, pa.Array]) -> pa.Array: + """ + >>> a = pa.array([1,2,3]) + >>> simple_array_to_nested(a).to_pylist() + [[1], [2], [3]] + """ + return pa.ListArray.from_arrays( + pa.array(np.arange(len(arr) + 1, dtype=np.int32)), pyarrow_column_to_array(arr) + ) + + +def hstack_pyarray_list(*arrays: tp.Union[pa.ChunkedArray, pa.Array]) -> pa.Array: + """ + Example with simple list: + >>> a = pa.array([[1], [2,3], [5], []]) + >>> b = pa.array([[-1, -3], [-11], [], [22]]) + >>> hstack_pyarray_list(a, b).to_pylist() + [[1, -1, -3], [2, 3, -11], [5], [22]] + + Example with nested lists : + >>> data = [np.array([[1, 2], [3, 4]]), np.array([[5, 6], [7, 8]]), np.array([[9, 10]])] + >>> list_array = nested_numpy_to_pyarrow(data) + >>> list_array.type + ListType(list[2]>) + >>> truncated_list_array = pc.list_slice(list_array, 1, 2) + [[[3, 4]], [[7, 8]], []] + >>> hstack_pyarray_list(list_array, truncated_list_array) + [[[1, 2], [3, 4], [3, 4]], + [[5, 6], [7, 8], [7, 8]], + [[9, 10]]] + """ + assert all(map(is_list_like, arrays)) + + lens = list(set(map(len, arrays))) + assert len(lens) == 1 + + list_off_views = [ + pyarrow_column_to_array(pc.list_flatten(arr.slice(i, 1))) + for i in range(lens[0]) + for arr in arrays + ] + + is_large = any(pa.types.is_large_list(arr.type) for arr in arrays) + + offsets = np.concatenate( + [np.array([0]), np.sum([pc.list_value_length(arr) for arr in arrays], axis=0)], + dtype=np.int64 if is_large else np.int32, + ).cumsum() + + cls = pa.LargeListArray if is_large 
else pa.ListArray + return cls.from_arrays(offsets, pa.concat_arrays(list_off_views)) + + +def apply_over_groups( + table: pa.Table, + grp_columns: tp.Optional[tp.List[tp.Optional[str]]], + table_mapper: tp.Callable[[pa.Table], pa.Table], +) -> pa.Table: + """ + Apply a mapping function to each group of a PyArrow table. + + Parameters: + - table: The input PyArrow table to be grouped and mapped. + - grp_columns: A list of column names to group the table by. + if `grp_columns=[None]` or `grp_columns=[]` or `grp_columns=None`, + `table_mapper` is applied on the full table. + Note also that None values in `grp_columns` will be filtered, + so one can you use grp_columns=[col1, col2] where each of col1 and col2 can be None. + - table_mapper: A callable function that takes a PyArrow table as input and returns a new PyArrow table. + + Returns: + - A new PyArrow table resulting from applying the `table_mapper` function to each group of the input table. + + Notes: + - The function adds a temporary column "__uuu_index" to the input table to facilitate grouping and sorting. + - The `table_mapper` function is applied to each group of the table, and the resulting tables are concatenated, + Therefore, the resulting sub-tables should have the same schema + - The function adds a temporary column "__uuu_index" to keep track of the original order of the rows. + So, it should be kept inchanged by `table_mapper`. + This column is removed in the final result. + """ + + # shortcut for no group case + if grp_columns is None: + return table_mapper(table) + + grp_columns = [x for x in grp_columns if x is not None] + if len(grp_columns) == 0: + return table_mapper(table) + + table = table.append_column( + "__uuu_index", pa.array(np.arange(len(table), dtype=np.int32)) + ) + split_grps = ( + table.select(grp_columns + ["__uuu_index"]) + .group_by(grp_columns) + .aggregate([("__uuu_index", "list")]) + ) + # shortcut for single group case + if len(split_grps) == 1: + return table_mapper(table.drop("__uuu_index")) + + # to iterate per rows we convert to pandas + # TODO : this could be called in parallel + results = [ + table_mapper(table.take(pa.array(ind))) + for ind in split_grps["__uuu_index_list"].to_pandas() + ] + + result = pa.concat_tables(results, promote_options="permissive") + del results + + if "__uuu_index" in result.column_names: + return result.sort_by("__uuu_index").drop("__uuu_index") + + return result.combine_chunks() + + +def first_element_in_nested_list(arr: tp.Union[pa.ChunkedArray, pa.Array]) -> pa.Array: + """ + >>> arr = pa.array([[[1, 2], [-1]], [[3, 2, 1], [], [4]] ]) + >>> first_element_in_nested_list(arr).to_pylist() + [[1, -1], [3, None, 4]] + """ + arr = pyarrow_column_to_array(arr) + return pa.ListArray.from_arrays( + arr.offsets, + pc.list_flatten( + pc.list_slice(arr.flatten(), start=0, stop=1, return_fixed_size_list=True) + ), + ) diff --git a/stopes/utils/config_utils.py b/stopes/utils/config_utils.py new file mode 100644 index 0000000..92dd67b --- /dev/null +++ b/stopes/utils/config_utils.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
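A short usage sketch for apply_over_groups defined above; the table and the per-group normalisation mapper are invented for illustration.

import numpy as np
import pyarrow as pa

from stopes.utils.arrow_utils import apply_over_groups

table = pa.table({"lang": ["en", "en", "fr"], "score": [1.0, 3.0, 5.0]})

def normalize_scores(tbl: pa.Table) -> pa.Table:
    # rescale scores within one group to sum to 1, leaving all other columns
    # (including the temporary "__uuu_index" bookkeeping column) unchanged
    scores = tbl.column("score").to_numpy()
    idx = tbl.schema.get_field_index("score")
    return tbl.set_column(idx, "score", pa.array(scores / scores.sum()))

result = apply_over_groups(table, ["lang"], normalize_scores)
# rows come back in their original order: the two "en" scores sum to 1, the single "fr" score becomes 1.0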
+ +from typing import Any, List, Optional + +from omegaconf import ListConfig + + +def parse_hydra_list(lst_config: Any) -> Optional[List]: + """ + robust parsing of hydra argumnts, accepting formats like + [1,2,3,5] OR [1] OR 1 OR 1,2 OR ['src','tgt'] + """ + if lst_config is None: + return None + if isinstance(lst_config, (list, ListConfig)): + lst_value: list = list(lst_config) + elif isinstance(lst_config, str): + lst_value = lst_config.strip().strip("[]").split(",") + else: + lst_value = [lst_config] + return lst_value diff --git a/stopes/utils/file_chunker_utils.py b/stopes/utils/file_chunker_utils.py index 1a5a324..4227e33 100644 --- a/stopes/utils/file_chunker_utils.py +++ b/stopes/utils/file_chunker_utils.py @@ -36,6 +36,34 @@ def find_offsets(filename: tp.Union[str, Path], num_chunks: int) -> tp.List[int] return offsets +def find_offsets_of_lines( + filename: tp.Union[str, Path], num_chunks: int, nrows: int +) -> tp.List[int]: + """ + Find the offsets of a text file that makes a total of `num_chunks` roughly equal-size chunks. + Here only the first `nrows` lines are read. This function should be used when `nrows` is + relatively small compared to the size of `filename`. + To find offsets of the entire file, please use `stopes.utils.file_chunker_utils.find_offsets()` + """ + offsets = [] + r = nrows % num_chunks + chunk_size = nrows // num_chunks + with open(filename, "r", encoding="utf-8") as f: + # Each of the r first chunks has one more line than the rest num_chunks - r + size = chunk_size + 1 + for _ in range(r): + offsets.append(f.tell()) + [f.readline() for _ in range(size)] + + for _ in range(0, num_chunks - r): + offsets.append(f.tell()) + [f.readline() for _ in range(chunk_size)] + + offsets.append(f.tell()) + + return offsets + + def find_line_numbers( filename: tp.Union[str, Path], start_offsets: tp.List[int] ) -> tp.List[int]: diff --git a/stopes/utils/language_codes.py b/stopes/utils/language_codes.py new file mode 100644 index 0000000..a9451b7 --- /dev/null +++ b/stopes/utils/language_codes.py @@ -0,0 +1,970 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import re +import typing as tp +from collections import defaultdict + +import numpy as np +import pandas as pd +from langcodes import Language, LanguageTagError +from sklearn.neighbors import BallTree +from tqdm.auto import tqdm + +from stopes.utils.web import cached_file_download + +logger = logging.getLogger(__name__) + + +# This is a copy of the languoid file from Glottolog (https://glottolog.org/meta/downloads), release 4.8. 
+# It is distributed by Glottolog under the CC-BY-4.0 license: https://creativecommons.org/licenses/by/4.0/ +GLOTTOLOG_DATA_URL = ( + "https://dl.fbaipublicfiles.com/nllb/languages/glottolog_languoid.csv" +) + + +def parse_language_code(lang: str) -> Language: + """Convert a language code (in any format) to a Language object, bypassing some formatting errors.""" + try: + lang_object = Language.get(lang) + return lang_object + except LanguageTagError as error: + # Testing the hi_IN_rom case + match = re.match("(?P[a-z]{2})_(?P[A-Z]{2})_rom", lang) + if match: + langtag = match.groupdict()["langtag"] + geotag = match.groupdict()["geotag"] + lang_object = Language.get(f"{langtag}_Latn_{geotag}") + return lang_object + raise error + + +def language_code_to_short_code( + orig_code: str, try_replacing_with_macro: bool = False +) -> str: + """ + Convert a language code (in any format) to its alpha-2 code, or, if it does not exist, to alpha-3. + If `try_replacing_with_macro` is set and the language does not have an alpha-2 code, but its macrolanguage does, + then the macrolanguage code is used (for example, Khalkha Mongolian `khk` may be "rounded" to just Mongolian `mn`). + """ + language: Language = parse_language_code(orig_code) + new_code = language.language + + if not isinstance(new_code, str): + logger.warning( + f"The code {orig_code} hasn't been matched, so language_code_to_short_code is returning it." + ) + return orig_code + + # Special case: `langcodes` package insists on renaming Tagalog to Filipino, but we don't want that rename. + # Filipino is a standardized version of Tagalog, so all Filipino is Tagalog, but not all Tagalog is Filipino. + if new_code == "fil" and orig_code.split("_")[0] in {"tgl", "tl"}: + new_code = "tl" + + if try_replacing_with_macro and len(new_code) == 3 and new_code in ISO_MICRO2MACRO: + code_macro = parse_language_code(ISO_MICRO2MACRO[new_code]).language + if isinstance(code_macro, str) and len(code_macro) == 2: + logger.info( + f"Replacing an alpha-3 code `{new_code}` (originally `{orig_code}`) with a macro-language alpha-2 code `{code_macro}`." + ) + new_code = code_macro + + return new_code + + +class LanguageMatcher: + """ + A class that matches the given language to the nearest neighbour from the given set of the target languages. + The proximity is determined by a cascade of criteria; see the `match` method docstring. + To work with it, please install the `[mono]` extra dependencies of Stopes. 
+ Usage example: + ``` + from stopes.utils.language_codes import SONAR_LANGS, LanguageMatcher + matcher = LanguageMatcher() + matcher.set_target_langs(SONAR_LANGS) + print(matcher.match("es")) # spa_Latn + print(matcher.match("en-US")) # eng_Latn + print(matcher.match("ar")) # arb_Arab + print(matcher.match("zh_TW")) # zho_Hant + print(matcher.match("no_XX")) # nob_Latn + print(matcher.match("ber")) # tzm_Tfng + print(matcher.match("foo")) # None + print(matcher.match("ns")) # None + print(matcher.match("ab")) # kat_Geor + ``` + """ + + def __init__(self, glottolog_data_path: str = "auto"): + if glottolog_data_path == "auto": + glottolog_data_path = str( + cached_file_download(GLOTTOLOG_DATA_URL, "glottolog_languoid.csv") + ) + + self.glottolog_data: pd.DataFrame = pd.read_csv(glottolog_data_path) + # Checking that the table contains information about genetic and geographic relations + expected_columns = {"id", "parent_id", "iso639P3code", "latitude", "longitude"} + assert not expected_columns.difference(self.glottolog_data.columns) + + # Imputting codes for some weird languages, such as "Berber" + for short_code, glottolog_code in MISSING_P3_CODES_IN_GLOTTOLOG.items(): + mask = (self.glottolog_data["id"] == glottolog_code) & ( + self.glottolog_data["iso639P3code"].isnull() + ) + if sum(mask) > 0: + self.glottolog_data.loc[mask, "iso639P3code"] = short_code + + self.code2row: tp.Dict[str, int] = ( + self.glottolog_data.iso639P3code.dropna() + .reset_index() + .set_index("iso639P3code")["index"] + .to_dict() + ) + self.fullid2row: tp.Dict[str, int] = ( + self.glottolog_data.id.dropna() + .reset_index() + .set_index("id")["index"] + .to_dict() + ) + self.row2children: tp.DefaultDict[int, tp.Set[int]] = defaultdict(set) + for i, row in self.glottolog_data.iterrows(): + if not pd.isna(row["parent_id"]): + self.row2children[self.fullid2row[row["parent_id"]]].add(row.name) + + # The attributes below will be set when setting the target languages + self.target_langs_set: tp.Set[str] = set() + self.target_langs_3_to_all: tp.Dict[str, tp.Set[str]] = {} + self.target_langs_pop: tp.Dict[str, int] = {} + self.matched_data: tp.Optional[pd.DataFrame] = None + self.coordinate_tree: tp.Optional[BallTree] = None + + def strip_script(self, language: str) -> str: + """Remove the script information from NLLB-styled language code: e.g. `eng_Latn` => `eng`.""" + message = f"Target languages should be formatted as `eng_Latn` or `eng`; got {language} instead." 
+ assert re.match("^[a-z]{3}(_[a-zA-Z]{4})?$", language), message + return language[:3] + + def set_target_langs(self, target_langs_list: tp.List[str]): + """ + Args: + target_langs_list: a list of languages in the 'eng_Latn' form + """ + self.target_langs_set = set(target_langs_list) + target_langs_3_to_all = defaultdict(set) + for lang in target_langs_list: + target_langs_3_to_all[self.strip_script(lang)].add(lang) + self.target_langs_3_to_all = dict(target_langs_3_to_all) + + self.target_langs_pop = { + self.strip_script(lang): max( + Language.get(self.strip_script(lang)).speaking_population(), + POPULATIONS.get(lang, 1), + ) + for lang in sorted(target_langs_list) + } + self.glottolog_data[MATCH_COLUMN] = self.glottolog_data["iso639P3code"].apply( + lambda x: x if x in self.target_langs_3_to_all else None + ) + self.glottolog_data[MATCH_POP_COLUMN] = self.glottolog_data.iso639P3code.apply( + self.target_langs_pop.get + ) + for lang_code, population in tqdm( + sorted(self.target_langs_pop.items(), key=lambda x: -x[1]) + ): + code3 = lang_code + if code3 not in self.code2row: + code3 = self.get_micro_language(code3) or code3 + if code3 not in self.code2row: + logger.warning(f"Could not find the target language: {code3}") + row_id = self.code2row[code3] + row = self.glottolog_data.loc[row_id] + while True: + if pd.isna(row.parent_id): + break + row = self.glottolog_data.loc[self.fullid2row[row.parent_id]] + if ( + not pd.isna(row[MATCH_POP_COLUMN]) + and row[MATCH_POP_COLUMN] > population + ): + break + self.glottolog_data.loc[row.name, MATCH_POP_COLUMN] = population + self.glottolog_data.loc[row.name, MATCH_COLUMN] = lang_code + self.matched_data = self.glottolog_data.dropna( + subset=["latitude", "longitude", MATCH_COLUMN] + ) + self.coordinate_tree = BallTree( + np.radians(self.matched_data[["latitude", "longitude"]]), metric="haversine" + ) + + def match(self, orig_code: str) -> tp.Optional[str]: + """ + For the input language code, find the most similar language in the set of target codes. + Try consecutively: + - Exact matching + - Exact matching after formatting the language with `langcodes` package + - Matching using the mapping of individual and macro languages + - Fuzzy matching to nearest genetic relative, using Gloggolog language genealogy tree + - Fuzzy matching to nearest georaphic neigbour, using Gloggolog language coordinates + """ + if orig_code in self.target_langs_set: + return orig_code + lang_obj_raw = parse_language_code(orig_code) + try: + code3 = lang_obj_raw.to_alpha3() + except LookupError: + logger.warning( + f"Could not parse the language code '{orig_code}'; matching it to none." + ) + return None + # some languages are parsed weirdly, e.g. prs is mapped to fas-AF, but fas is mapped back to pes + # TODO: take the territory into account, to fix this problem + result = self.choose_script_if_matched(code3, orig_code=orig_code, strict=True) + if result: + return result + result = self.find_gen_substitute(orig_code) + logger.warning(f"Fuzzy lookup for language {orig_code} => {result}") + return result + + def choose_script_if_matched( + self, code3: str, orig_code: tp.Optional[str] = None, strict=False + ) -> tp.Optional[str]: + """ + If the language is in the set of target languoids, + choose the languoid with the maching script, and return it. 
+ """ + candidates = self.target_langs_3_to_all.get(code3, set()) + + # If there are several scripts, try choosing the one + if len(candidates) > 1: + scripted = parse_language_code(orig_code or code3) + self.assume_script_(scripted) + candidates_new = {c for c in candidates if c[4:] == scripted.script} + if len(candidates_new) == 1: + candidates = candidates_new + elif not strict: + # TODO: try choosing a script in a less arbitrary way + logger.warning( + f"For {orig_code}, found several scripts: {candidates} ({len(candidates_new)} matching)" + ) + if len(candidates) == 1 or len(candidates) > 1 and not strict: + return list(candidates)[0] + # if no scripts are matched, returning None; the later steps will do fuzzier search + return None + + def find_geo_substitute(self, orig_code, fallback=True, not_the_same=False): + """ + For the given language code, find its nearest geographic neighbour (as represented by Glottolog coordinates) + that belongs to the set of the target languages. + If nothing found, fall back to matching by genetic proximity. + """ + code3 = self.standardize_code(orig_code) + row_id = self.code2row[code3] + row = self.glottolog_data.loc[row_id] + if not pd.isnull(row[MATCH_COLUMN]): + if row[MATCH_COLUMN] != orig_code or not not_the_same: + return row[MATCH_COLUMN] + if pd.isna(row.latitude) or pd.isna(row.longitude): + if fallback: + return self.find_gen_substitute(orig_code, fallback=False) + return + + assert ( + self.coordinate_tree is not None and self.matched_data is not None + ), "Please set the target language codes before the matching." + distances, indices = self.coordinate_tree.query( + np.radians([[row.latitude, row.longitude]]), k=50 + ) + neighbours = self.matched_data.iloc[indices[0]] + if not_the_same: + neighbours = neighbours[neighbours[MATCH_COLUMN] != orig_code] + if neighbours.shape[0] > 0: + matched_language = neighbours[MATCH_COLUMN].iloc[0] + return self.choose_script_if_matched(matched_language, strict=False) + + if fallback: + return self.find_gen_substitute(orig_code, fallback=False) + + def find_gen_substitute( + self, orig_code, verbose=False, fallback=True, not_the_same=False + ): + """ + For the given language code, find its nearest neighbour in the genetic tree (as represented by Glottolog) + that belongs to the set of the target languages. + In case of ambiguity, return the neighbour with highest population. + If nothing found, fall back to matching by geographic proximity. 
+ """ + code3 = self.standardize_code(orig_code) + if code3 not in self.code2row: + logger.warning(f"Code `{code3}` not found in Glottolog!") + return + row_id = self.code2row[code3] + row = self.glottolog_data.loc[row_id] + if not pd.isnull(row[MATCH_COLUMN]): + if row[MATCH_COLUMN] != orig_code or not not_the_same: + return self.choose_script_if_matched(row[MATCH_COLUMN], strict=False) + while True: + if verbose: + print(f"{row['name']} : looking for genetic neighbours") + children = self.glottolog_data.loc[ + sorted(self.row2children.get(row.name, set())) + ] + fltr = children[MATCH_COLUMN].notnull() + if not_the_same: + fltr = fltr & (children[MATCH_COLUMN] != orig_code) + children = children[fltr] + if children.shape[0] > 0: + largest_child_lang = children[MATCH_COLUMN][ + children[MATCH_POP_COLUMN].idxmax() + ] + return self.choose_script_if_matched(largest_child_lang, strict=False) + if pd.isna(row.parent_id): + if verbose: + print("found nothing genealogically, falling back to geography") + if fallback: + return self.find_geo_substitute( + orig_code, fallback=False, not_the_same=not_the_same + ) + row = self.glottolog_data.loc[self.fullid2row[row.parent_id]] + + def standardize_code(self, lang: str) -> str: + """ + Try to standardize a language code to match one in the glottolog data, by: + 1. Formatting it with langcode package + 2. Mapping a macrolanguage to its arbitrary individual language. + """ + if lang in self.code2row: + return lang + lang_object = parse_language_code(lang) + code3 = lang_object.to_alpha3() + if code3 not in self.code2row: + # Trying to match the macrolanguage to its variety + child = self.get_micro_language(code3) + if child: + return child + return code3 + + def get_micro_language(self, code3: str) -> tp.Optional[str]: + """If code is a macro language, return an individual language code of its first (usually arbitrary) child.""" + if code3 in ISO_MACRO2MICRO: + for child in ISO_MACRO2MICRO[code3]: + if child in self.code2row: + logger.warning( + f"Replacing the macrolanguage {code3} with its arbitrary sub-language: {child}" + ) + return child + # If no micro-language was found, we do nothing. 
+ return None + + def assume_script_(self, lang: Language): + """Modify a Language object, by assiging a script to it, if possible""" + lang.assume_script() + if lang.script is not None: + return + + # For Chinese, assuming Mandarin with simplified script in Mainland China, and with traditional one in Taiwan + if lang.language in {"zh", "cmn"}: + if lang.territory in {"CN", None, "XX"}: + lang.script = "Hans" + if lang.territory in {"TW"}: + lang.script = "Hant" + + +# The list of ~200 languages supported by SONAR sentence encoder +SONAR_LANGS = [ + "ace_Arab", + "ace_Latn", + "acm_Arab", + "acq_Arab", + "aeb_Arab", + "afr_Latn", + "ajp_Arab", + "aka_Latn", + "als_Latn", + "amh_Ethi", + "apc_Arab", + "arb_Arab", + "ars_Arab", + "ary_Arab", + "arz_Arab", + "asm_Beng", + "ast_Latn", + "awa_Deva", + "ayr_Latn", + "azb_Arab", + "azj_Latn", + "bak_Cyrl", + "bam_Latn", + "ban_Latn", + "bel_Cyrl", + "bem_Latn", + "ben_Beng", + "bho_Deva", + "bjn_Arab", + "bjn_Latn", + "bod_Tibt", + "bos_Latn", + "bug_Latn", + "bul_Cyrl", + "cat_Latn", + "ceb_Latn", + "ces_Latn", + "cjk_Latn", + "ckb_Arab", + "crh_Latn", + "cym_Latn", + "dan_Latn", + "deu_Latn", + "dik_Latn", + "dyu_Latn", + "dzo_Tibt", + "ell_Grek", + "eng_Latn", + "epo_Latn", + "est_Latn", + "eus_Latn", + "ewe_Latn", + "fao_Latn", + "fij_Latn", + "fin_Latn", + "fon_Latn", + "fra_Latn", + "fur_Latn", + "fuv_Latn", + "gaz_Latn", + "gla_Latn", + "gle_Latn", + "glg_Latn", + "grn_Latn", + "guj_Gujr", + "hat_Latn", + "hau_Latn", + "heb_Hebr", + "hin_Deva", + "hne_Deva", + "hrv_Latn", + "hun_Latn", + "hye_Armn", + "ibo_Latn", + "ilo_Latn", + "ind_Latn", + "isl_Latn", + "ita_Latn", + "jav_Latn", + "jpn_Jpan", + "kab_Latn", + "kac_Latn", + "kam_Latn", + "kan_Knda", + "kas_Arab", + "kas_Deva", + "kat_Geor", + "kaz_Cyrl", + "kbp_Latn", + "kea_Latn", + "khk_Cyrl", + "khm_Khmr", + "kik_Latn", + "kin_Latn", + "kir_Cyrl", + "kmb_Latn", + "kmr_Latn", + "knc_Arab", + "knc_Latn", + "kon_Latn", + "kor_Hang", + "lao_Laoo", + "lij_Latn", + "lim_Latn", + "lin_Latn", + "lit_Latn", + "lmo_Latn", + "ltg_Latn", + "ltz_Latn", + "lua_Latn", + "lug_Latn", + "luo_Latn", + "lus_Latn", + "lvs_Latn", + "mag_Deva", + "mai_Deva", + "mal_Mlym", + "mar_Deva", + "min_Latn", + "mkd_Cyrl", + "mlt_Latn", + "mni_Beng", + "mos_Latn", + "mri_Latn", + "mya_Mymr", + "nld_Latn", + "nno_Latn", + "nob_Latn", + "npi_Deva", + "nso_Latn", + "nus_Latn", + "nya_Latn", + "oci_Latn", + "ory_Orya", + "pag_Latn", + "pan_Guru", + "pap_Latn", + "pbt_Arab", + "pes_Arab", + "plt_Latn", + "pol_Latn", + "por_Latn", + "prs_Arab", + "quy_Latn", + "ron_Latn", + "run_Latn", + "rus_Cyrl", + "sag_Latn", + "san_Deva", + "sat_Beng", + "scn_Latn", + "shn_Mymr", + "sin_Sinh", + "slk_Latn", + "slv_Latn", + "smo_Latn", + "sna_Latn", + "snd_Arab", + "som_Latn", + "sot_Latn", + "spa_Latn", + "srd_Latn", + "srp_Cyrl", + "ssw_Latn", + "sun_Latn", + "swe_Latn", + "swh_Latn", + "szl_Latn", + "tam_Taml", + "taq_Latn", + "taq_Tfng", + "tat_Cyrl", + "tel_Telu", + "tgk_Cyrl", + "tgl_Latn", + "tha_Thai", + "tir_Ethi", + "tpi_Latn", + "tsn_Latn", + "tso_Latn", + "tuk_Latn", + "tum_Latn", + "tur_Latn", + "twi_Latn", + "tzm_Tfng", + "uig_Arab", + "ukr_Cyrl", + "umb_Latn", + "urd_Arab", + "uzn_Latn", + "vec_Latn", + "vie_Latn", + "war_Latn", + "wol_Latn", + "xho_Latn", + "ydd_Hebr", + "yor_Latn", + "yue_Hant", + "zho_Hans", + "zho_Hant", + "zsm_Latn", + "zul_Latn", +] + +# This mapping is extracted from https://iso639-3.sil.org/code_tables/download_tables +# For some macrolanguages, I put at the first position its preferred individual 
language +# (usually, the most widely used one, or the one with an official status) +# The rest is ordered alphabetically. +ISO_MACRO2MICRO = { + "aka": ["fat", "twi"], + "ara": [ + "arb", # Arabic => Modern Standard Arabic + "aao", + "abh", + "abv", + "acm", + "acq", + "acw", + "acx", + "acy", + "adf", + "aeb", + "aec", + "afb", + "ajp", + "apc", + "apd", + "arq", + "ars", + "ary", + "arz", + "auz", + "avl", + "ayh", + "ayl", + "ayn", + "ayp", + "bbz", + "pga", + "shu", + "ssh", + ], + "aym": ["ayc", "ayr"], + "aze": [ + "azj", + "azb", + ], # putting North Azerbaijani first, because it has an official status + "bal": ["bcc", "bgn", "bgp"], + "bik": ["bcl", "bhk", "bln", "bto", "cts", "fbl", "lbl", "rbl", "ubl"], + "bnc": ["ebk", "lbk", "obk", "rbk", "vbk"], + "bua": ["bxm", "bxr", "bxu"], + "chm": ["mhr", "mrj"], + "cre": ["crj", "crk", "crl", "crm", "csw", "cwd"], + "del": ["umu", "unm"], + "den": ["scs", "xsl"], + "din": ["dib", "dik", "dip", "diw", "dks"], + "doi": ["dgo", "xnr"], + "est": ["ekk", "vro"], + "fas": ["pes", "prs"], + "ful": ["ffm", "fub", "fuc", "fue", "fuf", "fuh", "fui", "fuq", "fuv"], + "gba": ["bdt", "gbp", "gbq", "gmm", "gso", "gya", "mdo"], + "gon": ["esg", "ggo", "gno", "wsg"], + "grb": ["gbo", "gec", "grj", "grv", "gry"], + "grn": ["gug", "gnw", "gui", "gun", "nhd"], + "hai": ["hax", "hdn"], + "hbs": ["bos", "cnr", "hrv", "srp"], + "hmn": [ + "blu", + "cqd", + "hea", + "hma", + "hmc", + "hmd", + "hme", + "hmg", + "hmh", + "hmi", + "hmj", + "hml", + "hmm", + "hmp", + "hmq", + "hms", + "hmw", + "hmy", + "hmz", + "hnj", + "hrm", + "huj", + "mmr", + "muq", + "mww", + "sfm", + ], + "iku": ["ike", "ikt"], + "ipk": ["esi", "esk"], + "jrb": ["ajt", "aju", "jye", "yhd", "yud"], + "kau": ["kby", "knc", "krt"], + "kln": ["enb", "eyo", "niq", "oki", "pko", "sgc", "spy", "tec", "tuy"], + "kok": ["gom", "knn"], + "kom": ["koi", "kpv"], + "kon": ["kng", "kwy", "ldi"], + "kpe": ["gkp", "xpe"], + "kur": ["ckb", "kmr", "sdh"], + "lah": ["hnd", "hno", "jat", "phr", "pmu", "pnb", "skr", "xhe"], + "lav": [ + "lvs", # Putting Standard Latvian first + "ltg", + ], + "luy": [ + "bxk", + "ida", + "lkb", + "lko", + "lks", + "lri", + "lrm", + "lsm", + "lto", + "lts", + "lwg", + "nle", + "nyd", + "rag", + ], + "man": ["emk", "mku", "mlq", "mnk", "msc", "mwk", "myq"], + "mlg": [ + "bhr", + "bjq", + "bmm", + "bzc", + "msh", + "plt", + "skg", + "tdx", + "tkg", + "txy", + "xmv", + "xmw", + ], + "mon": ["khk", "mvf"], + "msa": [ + "zsm", # Malay (macrolanguage) => Standard Malay (individual), a.k.a. 
Malaysian Malay + "bjn", + "btj", + "bve", + "bvu", + "coa", + "dup", + "hji", + "ind", + "jak", + "jax", + "kvb", + "kvr", + "kxd", + "lce", + "lcf", + "liw", + "max", + "meo", + "mfa", + "mfb", + "min", + "mly", + "mqg", + "msi", + "mui", + "orn", + "ors", + "pel", + "pse", + "tmw", + "urk", + "vkk", + "vkt", + "xmm", + "zlm", + "zmi", + ], + "mwr": ["dhd", "mtr", "mve", "rwr", "swv", "wry"], + "nep": ["dty", "npi"], + "nor": ["nno", "nob"], + "oji": ["ciw", "ojb", "ojc", "ojg", "ojs", "ojw", "otw"], + "ori": ["ory", "spv"], + "orm": ["gax", "gaz", "hae", "orc"], + "pus": ["pbt", "pbu", "pst"], + "que": [ + "cqu", + "qub", + "qud", + "quf", + "qug", + "quh", + "quk", + "qul", + "qup", + "qur", + "qus", + "quw", + "qux", + "quy", + "quz", + "qva", + "qvc", + "qve", + "qvh", + "qvi", + "qvj", + "qvl", + "qvm", + "qvn", + "qvo", + "qvp", + "qvs", + "qvw", + "qvz", + "qwa", + "qwc", + "qwh", + "qws", + "qxa", + "qxc", + "qxh", + "qxl", + "qxn", + "qxo", + "qxp", + "qxr", + "qxt", + "qxu", + "qxw", + ], + "raj": ["bgq", "gda", "gju", "hoj", "mup", "wbr"], + "rom": ["rmc", "rmf", "rml", "rmn", "rmo", "rmw", "rmy"], + "san": ["cls", "vsn"], + "sqi": ["aae", "aat", "aln", "als"], + "srd": ["sdc", "sdn", "src", "sro"], + "swa": ["swc", "swh"], + "syr": ["aii", "cld"], + "tmh": ["taq", "thv", "thz", "ttq"], + "uzb": ["uzn", "uzs"], + "yid": ["ydd", "yih"], + "zap": [ + "zaa", + "zab", + "zac", + "zad", + "zae", + "zaf", + "zai", + "zam", + "zao", + "zaq", + "zar", + "zas", + "zat", + "zav", + "zaw", + "zax", + "zca", + "zcd", + "zoo", + "zpa", + "zpb", + "zpc", + "zpd", + "zpe", + "zpf", + "zpg", + "zph", + "zpi", + "zpj", + "zpk", + "zpl", + "zpm", + "zpn", + "zpo", + "zpp", + "zpq", + "zpr", + "zps", + "zpt", + "zpu", + "zpv", + "zpw", + "zpx", + "zpy", + "zpz", + "zsr", + "ztc", + "zte", + "ztg", + "ztl", + "ztm", + "ztn", + "ztp", + "ztq", + "zts", + "ztt", + "ztu", + "ztx", + "zty", + ], + "zha": [ + "ccx", + "ccy", + "zch", + "zeh", + "zgb", + "zgm", + "zgn", + "zhd", + "zhn", + "zlj", + "zln", + "zlq", + "zqe", + "zyb", + "zyg", + "zyj", + "zyn", + "zzj", + ], + "zho": [ + "cmn", # Putting Mandarin as "Standard Chinese" first + "cdo", + "cjy", + "cnp", + "cpx", + "csp", + "czh", + "czo", + "gan", + "hak", + "hsn", + "lzh", + "mnp", + "nan", + "wuu", + "yue", + ], + "zza": ["diq", "kiu"], +} + + +ISO_MICRO2MACRO = { + micro: macro for macro, micros in ISO_MACRO2MICRO.items() for micro in micros +} + +# This mapping augments the one produced by Langcodes. +# It affects the choice between children languages when finding genetic neighbours +POPULATIONS = dict( + acm_Arab=28_000_000, # https://en.wikipedia.org/wiki/Mesopotamian_Arabic + acq_Arab=12_000_000, # https://en.wikipedia.org/wiki/Ta%CA%BDizzi-Adeni_Arabic divided by 2 + ajp_Arab=27_000_000, # https://en.wikipedia.org/wiki/South_Levantine_Arabic + als_Latn=1_800_000, # https://en.wikipedia.org/wiki/Tosk_Albanian + apc_Arab=27_000_000, # https://en.wikipedia.org/wiki/North_Levantine_Arabic divided by 2 + arb_Arab=270_000_000, # https://en.wikipedia.org/wiki/Modern_Standard_Arabic + ayr_Latn=1_700_000, # https://en.wikipedia.org/wiki/Aymara_language [choose central-southern?] + azb_Arab=13_000_000, # https://en.wikipedia.org/wiki/Azerbaijani_language#South_Azerbaijani + azj_Latn=9_000_000, # https://en.wikipedia.org/wiki/Azerbaijani_language#North_Azerbaijani + cjk_Latn=2_500_000, # https://en.wikipedia.org/wiki/Chokwe_language + dik_Latn=4_200_000, # https://en.wikipedia.org/wiki/Dinka_language [choose southwestern?] 
+ gaz_Latn=45_500_000, # https://en.wikipedia.org/wiki/Oromo_language [choose western central?] + kbp_Latn=1_000_000, # https://en.wikipedia.org/wiki/Kabiye_language + khk_Cyrl=3_000_000, # https://en.wikipedia.org/wiki/Khalkha_Mongolian + kmr_Latn=16_000_000, # https://en.wikipedia.org/wiki/Kurmanji + knc_Arab=8_450_000, # https://en.wikipedia.org/wiki/Central_Kanuri + knc_Latn=8_450_000, # https://en.wikipedia.org/wiki/Central_Kanuri + lus_Latn=1_000_000, # https://en.wikipedia.org/wiki/Mizo_language + lvs_Latn=1_300_000, # https://en.wikipedia.org/wiki/Latvian_language minus Latgalian + npi_Deva=19_000_000, # https://en.wikipedia.org/wiki/Nepali_language + ory_Orya=35_000_000, # https://en.wikipedia.org/wiki/Odia_language + pbt_Arab=16_000_000, # https://en.wikipedia.org/wiki/Southern_Pashto + pes_Arab=57_000_000, # https://en.wikipedia.org/wiki/Iranian_Persian + plt_Latn=10_893_000, # https://en.wikipedia.org/wiki/Malagasy_language, chose Plateau + quy_Latn=918_000, # https://en.wikipedia.org/wiki/Ayacucho_Quechua + swh_Latn=18_000_000, # https://en.wikipedia.org/wiki/Swahili_language + taq_Latn=900_000, # https://en.wikipedia.org/wiki/Tamasheq_language + taq_Tfng=900_000, # https://en.wikipedia.org/wiki/Tamasheq_language + twi_Latn=16_000_000, # https://en.wikipedia.org/wiki/Twi + uzn_Latn=28_000_000, # https://en.wikipedia.org/wiki/Uzbek_language + ydd_Hebr=600_000, # https://en.wikipedia.org/wiki/Yiddish_dialects + zsm_Latn=33_000_000, # https://en.wikipedia.org/wiki/Malaysian_Malay +) + +MATCH_COLUMN = "matched_lang_code3" +MATCH_POP_COLUMN = "matched_lang_population" + +# Some language groups, like "Berber", are present in Glottolog data, but don't have an iso639-3 code. +MISSING_P3_CODES_IN_GLOTTOLOG = { + "ber": "berb1260", +} diff --git a/stopes/utils/parquet_dataloader.py b/stopes/utils/parquet_dataloader.py index ee884c8..4bbff9f 100644 --- a/stopes/utils/parquet_dataloader.py +++ b/stopes/utils/parquet_dataloader.py @@ -91,9 +91,7 @@ def __init__( pq.filters_to_expression(filters) if filters else None ) # split_row_groups=True is not supported yet - self.source_ds = pq.ParquetDataset( - self.dataset_path, validate_schema=True, filters=filters - ) + self.source_ds = pq.ParquetDataset(self.dataset_path, filters=filters) self.columns = columns or self.source_ds.schema.names assert set(self.columns).issubset(set(self.source_ds.schema.names)) diff --git a/stopes/utils/sharding/__init__.py b/stopes/utils/sharding/__init__.py new file mode 100644 index 0000000..0952fcc --- /dev/null +++ b/stopes/utils/sharding/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/stopes/utils/sharding/abstract_shards.py b/stopes/utils/sharding/abstract_shards.py new file mode 100644 index 0000000..2bba95d --- /dev/null +++ b/stopes/utils/sharding/abstract_shards.py @@ -0,0 +1,320 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
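Stepping back to the macro/micro language mappings defined earlier in this diff, a minimal usage sketch of the ordering convention (first list element = canonical individual language). The helper names below are hypothetical, and the sketch assumes `ISO_MACRO2MICRO` and `ISO_MICRO2MACRO` are imported from that language-code module:

def canonical_individual_code(code3: str) -> str:
    # The first entry of each ISO_MACRO2MICRO list is the canonical individual
    # language; individual codes pass through unchanged.
    return ISO_MACRO2MICRO.get(code3, [code3])[0]


def macrolanguage_of(code3: str) -> str:
    # Reverse lookup: individual code -> macrolanguage code (identity if none).
    return ISO_MICRO2MACRO.get(code3, code3)


assert canonical_individual_code("zho") == "cmn"  # Chinese -> Mandarin
assert macrolanguage_of("yue") == "zho"  # Yue (Cantonese) -> Chinese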
+
+
+import dataclasses
+import functools
+import typing as tp
+from abc import ABC, abstractmethod
+from collections import defaultdict
+from contextlib import AbstractContextManager
+from dataclasses import dataclass
+from enum import Enum
+from glob import iglob
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+import pyarrow.compute as pc
+import pyarrow.dataset as ds
+
+BatchFormat = Enum("BatchFormat", ["PANDAS", "NUMPY", "ARROW"])
+BatchType = tp.Union[pd.DataFrame, tp.Dict[str, np.ndarray], pa.Table]
+
+
+def batch_length(batch: tp.Optional[BatchType]) -> int:
+    if batch is None:
+        return 0
+
+    if isinstance(batch, dict):
+        if len(batch) == 0:
+            return 0
+        return len(next(iter(batch.values())))
+
+    return len(batch)
+
+
+def batch_tail(batch: BatchType, nb_samples: int) -> BatchType:
+    """
+    take nb_samples from the tail of the batch
+    """
+    if isinstance(batch, pd.DataFrame):
+        return batch.tail(nb_samples)
+    elif isinstance(batch, dict):
+        raise ValueError("batch_tail cannot be implemented for dict")
+    elif isinstance(batch, pa.Table):
+        return batch.slice(len(batch) - nb_samples)
+    else:
+        raise ValueError("data type is not understood:", type(batch))
+
+
+def batch_to_table(batch: BatchType) -> pa.Table:
+    if isinstance(batch, pd.DataFrame):
+        return pa.Table.from_pandas(batch, preserve_index=False)
+    elif isinstance(batch, dict):
+        return pa.Table.from_pydict(batch)
+    elif isinstance(batch, pa.Table):
+        return batch.combine_chunks()
+    else:
+        raise ValueError("data type is not understood:", type(batch))
+
+
+def batch_to_pandas(batch: BatchType) -> pd.DataFrame:
+    if isinstance(batch, pd.DataFrame):
+        return batch
+    elif isinstance(batch, dict):
+        return pd.DataFrame(batch)
+    elif isinstance(batch, pa.Table):
+        return batch.to_pandas()
+    else:
+        raise ValueError("data type is not understood:", type(batch))
+
+
+def arrow_table_to_batch(table: pa.Table, batch_format: BatchFormat) -> BatchType:
+    if batch_format == BatchFormat.ARROW:
+        return table.combine_chunks()
+    elif batch_format == BatchFormat.PANDAS:
+        return table.to_pandas(split_blocks=True, self_destruct=True)
+    elif batch_format == BatchFormat.NUMPY:
+        return table.to_pydict()
+    else:
+        raise ValueError(f"Unknown batch format {batch_format}")
+
+
+def concat_batches(list_of_batches: tp.List[BatchType]) -> tp.Optional[BatchType]:
+    if len(list_of_batches) == 0:
+        return None
+
+    types_ = list(set(map(type, list_of_batches)))
+    assert len(types_) == 1
+    common_type = list_of_batches[0]
+
+    if isinstance(common_type, pd.DataFrame):
+        return pd.concat(list_of_batches, axis=0)
+    elif isinstance(common_type, pa.Table):
+        return pa.concat_tables(list_of_batches).combine_chunks()
+    elif isinstance(common_type, dict):
+        assert (
+            len(set(tuple(bb.keys()) for bb in list_of_batches)) == 1
+        ), "all batches should share the same keys"
+        return {
+            key: np.concatenate([batch[key] for batch in list_of_batches], axis=0)
+            for key in list_of_batches[0].keys()
+        }
+    else:
+        raise ValueError("data type is not understood", common_type)
+
+
+@dataclass
+class Shard(AbstractContextManager):
+    """
+    An abstract dataclass holding configuration that represents a piece of tabular data.
+    It also exposes methods to read this data (possibly in smaller batches) in various data formats.
+
+    Optionally, only a subset of the columns present in the data schema can be read.
+ It uses `filter` for row level filtering with the implementation based on `pyarrow.dataset.Expression`: + - see https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Expression.html + - example for "col1" and "col2" (present in data schema): + >>> import pyarrow.compute as pc + >>> (pc.field("col1") < pc.scalar(3)) | (pc.field("col2") > 7) + """ + + filter: tp.Optional[pa.dataset.Expression] + + def __post_init__(self) -> None: + ... + + @abstractmethod + def to_batches( + self, + batch_size: tp.Optional[int], + columns: tp.Optional[tp.List[str]] = None, + batch_format: BatchFormat = BatchFormat.PANDAS, + ) -> tp.Iterator[BatchType]: + """ + Return a sequential mini-batch iterator of given `batch-size` over the underlying data. + If `batch_size` is None, it will read the whole data as a single batch. + + Args: + batch_size (tp.Optional[int]): batch size to use + columns (tp.Optional[tp.List[str]], optional): columns to read. Defaults to None. + batch_format (BatchFormat, optional): type of batch container. Defaults to BatchFormat.PANDAS. + + Returns: + tp.Iterator[BatchType]: + """ + ... + + @abstractmethod + def __iter__(self) -> tp.Iterator[tp.Any]: + ... + + +@dataclass +class ShardWithSkip(Shard, ABC): + """ + A Shard, but with extra info to help skip rows at the beginning of the shard. + This is useful when doing checkpointing or resuming failed jobs. + """ + + skip_n_rows: int + + def __post_init__(self) -> None: + if not self.skip_n_rows: + self.skip_n_rows = 0 + + +@dataclass +class PartitionedDataMapperState: + iteration_index: int + iteration_value: Shard + written_batches_index: int + intermediate_index: int + input_rows_written: int = 0 + + +@dataclass +class InputShardingConfig(ABC): + """ + Define how to handle the input data into different shards. + + Args: + - input_file: input to the sharding (a file or data identifier). expected to be + tp.Union[str, tp.List[tp.Union[str, Path]], Path] + - filters_expr, str is python evaluated string corresponding to the row level filtering + implemented as `pyarrow.dataset.Expression`: + * see https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Expression.html + * for instance, + filters_expr= '(pc.field("col1") < 3) | (pc.field("col2") > 7)' + will result in valid filter if "col1" and "col2" are numeric columns present in data schema. + - batch_size (int): size of the batch within one shard + - columns: List of column headers to construct the shards. + - take: tp.Optional[int], if not None, can be used in subclasses to use only + `take` number of shards in `make_shards(...)` method and ignoring others shards. + This option can be used for debugging or sampling. + - skip_n_rows_per_shard: tp.Optional[int], if not None, can be used to skip a number of rows at the beginning + of this shard. Useful when doing checkpointing or resuming failed jobs. The shard will not skip automatically, this info is used in the mapper. + """ + + input_file: ( + tp.Any + ) # we expect it to be tp.Union[str, tp.List[tp.Union[str, Path]], Path] + batch_size: tp.Optional[int] = None + columns: tp.Optional[tp.List[str]] = None + batch_format: BatchFormat = BatchFormat.PANDAS + filters_expr: tp.Optional[str] = None # for passing e.g. 
through Hydra config + take: tp.Optional[int] = None + skip_n_rows_per_shard: tp.Dict[int, int] = dataclasses.field( + default_factory=lambda: defaultdict(int) + ) + + def __post_init__(self): + self.filter: tp.Optional[pa.compute.Expression] = ( + eval(self.filters_expr, {}, {"pa": pa, "ds": ds, "pc": pc}) + if self.filters_expr + else None + ) + self.validate() + + @functools.cached_property + def input_dataset(self) -> tp.Sequence[Path]: + """ + Parse the input_file and construct the input_dataset as a list. + Default behaviour (can be overriden in concrete subclass of InputConfig): + 1) If the input_file is a str, assume this is glob pattern and construct + input_dataset as a list of matching paths + 2) If input_file is a Path, make input_dataset to be a single-item out of it + 3) If input_file is a a list of string, treat each item as a glob pattern + and repeat 1) + """ + if isinstance(self.input_file, str): + return sorted(map(Path, iglob(self.input_file))) + elif isinstance(self.input_file, Path): + return [self.input_file] + elif isinstance(self.input_file, (list, tuple)): + return [ + y for f in self.input_file for y in sorted(map(Path, iglob(str(f)))) + ] + else: + raise ValueError( + f"Unsupported input_file format {self.input_file} of type {type(self.input_file)}" + ) + + def validate(self) -> None: + """Method that validates the files existence/readability + other params compatibility (filters, columns, ...)""" + pass + + @abstractmethod + def make_shards(self, **kwargs) -> tp.List[Shard]: + """ + returns a list shards corresponding to the given configuration + + """ + ... + + def head( + self, + nb_top: int = 5, + columns: tp.Optional[tp.List[str]] = None, + batch_format: tp.Optional[BatchFormat] = None, + ) -> BatchType: + shard = self.make_shards()[0] + with shard: + return next( + shard.to_batches( + nb_top, + columns=columns or self.columns, + batch_format=batch_format or self.batch_format, + ) + ) + + +@dataclass +class OutputDatasetConfig(ABC): + """ + Config defining how one shard outputs its content. + + Args: + - dataset_path: str, a folder where the output dataset will be written + - validate_schema: bool, if True, it makes sure that the all written batches follows the same schema + - batch_size, optional int, default=None. If provided, `write_batch` should be called as soon as the size of processed batch > `write_each_nb_samples` + This should allow to write intermediate results (without loosing them in case of errors/preemtions) and to free some memory. + - compression: str, the format specific compression that is applied on output files + * use `compression=None` to deactivate any compression + * use `compression="default"` to overwrite this value in subclasses + - for parquet, default compression is "snappy" + - for text, default compression is None (deactivated)""" + + dataset_path: str + validate_schema: bool = False # it's not yet completely supported + batch_size: tp.Optional[int] = None + compression: tp.Optional[str] = "default" + + def __post_init__(self) -> None: + self.expected_schema: tp.Optional[ + pa.Schema + ] = None # TODO: how to pass it through Hydra serialization ? + + if self.compression is not None: + self.compression = self.compression.lower() + + @abstractmethod + def write_batch( + self, + batch: BatchType, + iteration_index: tp.Sequence[int], + metadata: tp.Dict[str, tp.Any] = {}, + state_checkpoint: tp.Optional[PartitionedDataMapperState] = None, + ) -> tp.List[Path]: + ... 
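To make the `filters_expr` contract described above concrete, here is a minimal sketch of how such a string becomes a `pyarrow.dataset.Expression`, mirroring `InputShardingConfig.__post_init__` (the column names are placeholders):

import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.dataset as ds

filters_expr = '(pc.field("col1") < 3) | (pc.field("col2") > 7)'
# Only `pa`, `ds` and `pc` are exposed to the evaluated string.
filter_ = eval(filters_expr, {}, {"pa": pa, "ds": ds, "pc": pc})
# `filter_` can then be attached to a Shard and pushed down when reading batches.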
+ + @abstractmethod + def reload_state( + self, + shard: Shard, + ) -> tp.Optional[PartitionedDataMapperState]: + ... diff --git a/stopes/utils/sharding/hf_shards.py b/stopes/utils/sharding/hf_shards.py new file mode 100644 index 0000000..53d7cfc --- /dev/null +++ b/stopes/utils/sharding/hf_shards.py @@ -0,0 +1,189 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import functools +import os +import typing as tp +from dataclasses import dataclass +from pathlib import Path + +import pyarrow as pa +from datasets import Dataset, DownloadMode, concatenate_datasets, load_dataset + +from stopes.utils.sharding.abstract_shards import ( + BatchFormat, + BatchType, + InputShardingConfig, + Shard, + arrow_table_to_batch, +) + + +@dataclass +class HFShard(Shard): + """ + A wrapper over HuggingFace datatsets's Dataset to make it + compatible with stopes Shard and Mapper API. + + Args: + path_or_name (str or Path): Path to a local dataset, or name of the dataset + in HuggingFace Hubg + data_dir: HuggingFace-specific data dir (kind of subset of the dataset) + split: (str) Split of the data. If None, all splits will be downloaded. Can + accept "train", "test", "validation", or a HF split syntax (see + https://huggingface.co/docs/datasets/v1.11.0/splits.html) + use_cache (bool): If we should reuse the cached dataset, or download from + HF Hub. This param has no impact if `path_or_name` is a local directory. + Default True + index (int): Index of the shard + num_shards: Total number of shards in which the current one is one member + """ + + path_or_name: tp.Union[str, Path] + data_dir: tp.Optional[str] = None + split: tp.Optional[str] = None + cached: bool = True + index: tp.Optional[int] = None + num_shards: int = 1 + trust_remote_code: bool = False + + def __post_init__(self): + if self.filter: + raise NotImplementedError( + f"Arrow-syntax filter is not supported in HF shard. Get {self.filter}" + ) + if self.index is None: + assert self.num_shards == 1, f"Unknown shard index {self.index}" + else: + assert ( + self.index < self.num_shards + ), f"Cannot make shard {self.index} from {self.num_shards} shards" + + self._data: tp.Optional[Dataset] = None + + # We could only iterate the underlying dataset in "nornmal" mode (via __iter__) or + # in "converted" mode (via to_batches()) + # mode = 0 --> not started + # mode = 1 --> _data is being consumed via __iter__() + # mode = 2 --> _data is being consumed via to_batches() + self._mode = 0 + + def __enter__(self): + """ + Create the underlying dataset, without loading them into main memory + + Note: When we enter the first shard, if `path_or_name` is not a local directory, + the underlying dataset will be downloaded to a _central_ local data dir that is + shared by all workers. This local data can be customized via os environment + STOPES_HF_CACHE + """ + + _download_mode = ( + DownloadMode.REUSE_DATASET_IF_EXISTS + if self.cached + else DownloadMode.FORCE_REDOWNLOAD + ) + + _cache_dir = None + if not Path(self.path_or_name).is_dir() and self.cached: + # TODO: We trust that HF DownloadManager will perform proper locks to avoid + # concurrent downloads from multiple workers. 
If error occurs, consider using + # stopes FileLock explicity + _cache_dir = os.getenv( + "STOPES_HF_CACHE", Path.home() / ".cache" / "huggingface" / "datasets" + ) + + _data = load_dataset( + path=self.path_or_name, + data_dir=self.data_dir, + cache_dir=_cache_dir, + download_mode=_download_mode, + split=self.split, + trust_remote_code=self.trust_remote_code, + ) + if self.split is None: # _data is a DatasetDict, convert to Dataset + _data = concatenate_datasets( + [_data["train"], _data["test"], _data["validation"]] + ) + + if self.num_shards > 0: + _data = _data.shard(num_shards=self.num_shards, index=self.index) + self._data = _data + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self._mode = 0 + self._data = None + + def __iter__(self) -> tp.Iterator[tp.Union[tp.Dict, tp.List]]: + if self._data is None: + raise ValueError("shard is not entered yet") + assert self._mode != 2, "Consumption mode changed during iterating" + if self._mode == 0: + self._mode = 1 + yield from self._data + self._mode = 0 + + def to_batches( + self, + batch_size: tp.Optional[int], + columns: tp.Optional[tp.List[str]] = None, + batch_format: BatchFormat = BatchFormat.ARROW, + ) -> tp.Iterator[BatchType]: + if self._data is None: + raise ValueError("shard is not entered yet") + assert self._mode != 1, "Consumption mode changed during iterating" + if self._mode == 0: + self._mode = 2 + for item in self._data.iter(batch_size=batch_size): + _table = pa.Table.from_pydict(item) + _table = arrow_table_to_batch(_table, batch_format=batch_format) + # TODO: Move the column projection in early step (before mini batch begins) + if columns: + _table = _table.select(columns) + yield _table + self._mode = 1 + + +@dataclass +class HFInputConfig(InputShardingConfig): + data_dir: tp.Optional[str] = None + split: tp.Optional[str] = None + cached: bool = True + num_shards: int = 1 + trust_remote_code: bool = False + + def __post_init__(self): + super().__post_init__() + assert ( + len(self.skip_n_rows_per_shard) == 0 + ), "skipping not supported for this shard type" + + @functools.cached_property + def input_dataset(self): + return self.input_file + + def make_shards(self, **kwargs): + assert isinstance( + self.input_dataset, (Path, str) + ), f"Expect input dataset to be Path or str, get {type(self.input_dataset)}" + if self.filters_expr: + raise NotImplementedError("Not implemented yet for HF Shards") + + return [ + HFShard( + filter=None, + path_or_name=self.input_dataset, + data_dir=self.data_dir, + split=self.split, + cached=self.cached, + index=i, + num_shards=self.num_shards, + trust_remote_code=self.trust_remote_code, + ) + for i in range(min(self.num_shards, self.take or self.num_shards)) + ] diff --git a/stopes/utils/sharding/json_shards.py b/stopes/utils/sharding/json_shards.py new file mode 100644 index 0000000..9413bd6 --- /dev/null +++ b/stopes/utils/sharding/json_shards.py @@ -0,0 +1,122 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
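A hedged usage sketch for the HuggingFace shards defined above; "some_hf_dataset" and the "text" column are placeholders, not real identifiers:

from stopes.utils.sharding.hf_shards import HFInputConfig

config = HFInputConfig(
    input_file="some_hf_dataset",  # dataset name on the HF Hub, or a local path
    split="train",
    num_shards=4,
)
for shard in config.make_shards():
    with shard:  # loads the dataset lazily, then takes this shard's slice
        for table in shard.to_batches(batch_size=1_000, columns=["text"]):
            ...  # a pyarrow.Table by default (BatchFormat.ARROW)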
+ +import json +from dataclasses import dataclass +from functools import partial +from pathlib import Path +from typing import Any, Dict, Iterator, List, Optional, Union + +import pyarrow as pa +from typing_extensions import Self + +from stopes.core.utils import batch +from stopes.utils.file_chunker_utils import Chunker, find_offsets, find_offsets_of_lines +from stopes.utils.sharding.abstract_shards import ( + BatchFormat, + BatchType, + InputShardingConfig, + Shard, + arrow_table_to_batch, +) + + +@dataclass +class JSONShard(Shard): + input_file: Union[str, Path] + start_offset: int = 0 + end_offset: Optional[int] = None + + def __enter__(self) -> Self: + self.file_handler = Chunker( + str(self.input_file), + start_offset=self.start_offset, + end_offset=self.end_offset, + ) + self.reader = self.file_handler.__enter__() # type: ignore + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + if hasattr(self, "file_handler"): + self.file_handler.__exit__(exc_type, exc_val, exc_tb) + del self.file_handler + if hasattr(self, "reader"): + self.reader = None + del self.reader + + def _select_columns( + self, + line: str, + columns: Optional[List[str]] = None, + ) -> Dict[str, Any]: + data = json.loads(line) + if columns is None: + return data + return {k: v for k, v in data.items() if k in columns} + + def __iter__(self): + if self.reader is None: + raise ValueError("shard is not entered yet") + lines = iter(self.reader) + mapper_func = partial(self._select_columns, columns=None) + yield from map(mapper_func, lines) # type: ignore + + def to_batches( + self, + batch_size: Optional[int], + columns: Optional[List[str]] = None, + batch_format: BatchFormat = BatchFormat.PANDAS, + ) -> Iterator[BatchType]: + mapper_func = partial(self._select_columns, columns=columns) + + with self as reading_context: + lines = iter(reading_context.reader) # type: ignore + lines = map(mapper_func, lines) # type: ignore + if batch_size is None: + # Read the whole file as a single batch + batched = [list(lines)] + else: + assert batch_size > 0, f"Invalid batch size: {batch_size}" + batched = batch(lines, batch_size=batch_size) # type: ignore + for _batch in batched: + table = pa.Table.from_pylist(_batch) + yield arrow_table_to_batch(table, batch_format) + + +@dataclass +class JSONShardConfig(InputShardingConfig): + input_file: Union[str, Path] + num_shards: int = 1 + nrows: Optional[int] = None + partition_columns: Optional[List[str]] = None + + def __post_init__(self): + super().__post_init__() + assert self.num_shards > 0, f"invalid number of shards ({self.num_shards})" + assert ( + len(self.skip_n_rows_per_shard) == 0 + ), "skipping not supported for this shard type" + + def validate(self) -> None: + assert Path(self.input_file).exists() + pass + + def make_shards(self, **kwargs) -> List[Shard]: + if self.nrows: + offsets = find_offsets_of_lines( + self.input_file, self.num_shards, self.nrows + ) + else: + offsets = find_offsets(self.input_file, self.num_shards) + return [ + JSONShard( + filter=None, + input_file=self.input_file, + start_offset=start, + end_offset=end, + ) + for start, end in zip(offsets, offsets[1:]) + ] diff --git a/stopes/utils/sharding/parquet_shards.py b/stopes/utils/sharding/parquet_shards.py new file mode 100644 index 0000000..02b9127 --- /dev/null +++ b/stopes/utils/sharding/parquet_shards.py @@ -0,0 +1,528 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import functools +import importlib.util +import logging +import typing as tp +import uuid +from abc import ABC, abstractmethod +from contextlib import contextmanager +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from typing import Any, List, Optional, Sequence, Tuple, Union + +import cloudpickle +import numpy as np +import pyarrow as pa +import pyarrow.compute as pc +import pyarrow.dataset as ds +import pyarrow.parquet as pq +import xxhash +from pyarrow.dataset import get_partition_keys +from tqdm.auto import tqdm +from tqdm.contrib.logging import logging_redirect_tqdm + +from stopes.core.utils import batch as batched +from stopes.utils.arrow_utils import add_metadata_to_table, hash_table_with_schema +from stopes.utils.sharding.abstract_shards import ( + BatchFormat, + BatchType, + InputShardingConfig, + OutputDatasetConfig, + PartitionedDataMapperState, + Shard, + ShardWithSkip, + arrow_table_to_batch, + batch_length, + batch_to_table, +) + +logger = logging.getLogger("stopes.launcher") + + +import signal +from functools import wraps + + +class TimeoutException(Exception): + """Exception to raise on a timeout""" + + ... + + +def timeout(seconds=60, error_message="Function call timed out"): + def decorator(func): + def _handle_timeout(signum, frame): + raise TimeoutException(error_message) + + @wraps(func) + def wrapper(*args, **kwargs): + signal.signal(signal.SIGALRM, _handle_timeout) + signal.alarm(seconds) + try: + result = func(*args, **kwargs) + finally: + signal.alarm(0) # Cancel the alarm + return result + + return wrapper + + return decorator + + +def get_filesystem_from_path( + uri: Union[Union[str, Path], Sequence[Union[str, Path]]], + **kwargs, +) -> Tuple[Union[Union[str, Path], Sequence[Union[str, Path]]], Any]: + return uri, None + + +@dataclass +class ParquetShardBase(ShardWithSkip, ABC): + @abstractmethod + def mini_batches( + self, max_chunk_size: int, columns: tp.Optional[tp.List[str]] = None + ) -> tp.Iterator[pa.Table]: + """ + Returns an iterator over the mini-batches of this shard. Mini batches + might be aggregated to form batches of the requested size if they are too + small. + + Args: + max_chunk_size (int): Maximum size of the mini-batches. Mini batches might + be smaller. + """ + + def __iter__(self): + raise NotImplementedError( + "single-item iteration not yet implemented in parquetShard. " + "Use to_batches() with batch_size = 1 instead" + ) + + def to_batches( + self, + batch_size: tp.Optional[int], + columns: tp.Optional[tp.List[str]] = None, + batch_format: BatchFormat = BatchFormat.PANDAS, + ) -> tp.Iterator[BatchType]: + assert batch_size, "you need to specify a batch_size." 
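+        # The loop below accumulates filtered mini-batches until at least
+        # `batch_size` rows are available, yields exactly `batch_size` rows,
+        # and carries any remainder over into the next emitted batch.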
+ + table: pa.Table = None + + for new_table in self.mini_batches(batch_size, columns): + if self.filter is not None: + new_table = new_table.filter(self.filter) + # Note that the filters can reduce the number of rows, + # so we combine the results from several batches filtered mini-batches + if len(new_table) > 0: + if table is not None: + table = pa.concat_tables([table, new_table]) + else: + table = new_table + + if len(table) >= batch_size: + table_to_return = table.slice(0, batch_size).combine_chunks() + yield arrow_table_to_batch(table_to_return, batch_format) + + if len(table) == batch_size: + table = None + else: + table = table.slice(batch_size, len(table) - batch_size) + + # if we have a table left, it means that the last batch was smaller than `batch_size` + # yield the rest + if table is not None and len(table) > 0: + yield arrow_table_to_batch(table, batch_format) + + +@dataclass +class ParquetShard(ParquetShardBase): + fragment: List[pa.dataset.Fragment] + + def __post_init__(self) -> None: + self._first_fragment: pa.dataset.Fragment = self.fragment[0] + + def __enter__(self) -> "ParquetShard": + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + pass + + @property + def nb_rows(self) -> int: + return sum(frag.metadata.num_rows for frag in self.fragment) + + @functools.cached_property + def partition_columns(self): + return list( + get_partition_keys(self._first_fragment.partition_expression).keys() + ) + + @functools.cached_property + def columns(self) -> tp.List[str]: + return list(self._first_fragment.physical_schema.names) + self.partition_columns + + def fragment_columns( + self, columns: tp.Optional[tp.List[str]] = None + ) -> tp.Optional[tp.List[str]]: + if columns is None: + return None + + assert set(columns).issubset(set(self.columns)), ( + sorted(set(columns) - set(self.columns)), + self.columns, + ) + return sorted(set(columns) - set(self.partition_columns)) + + def _add_partitioning_columns( + self, + frag: pa.dataset.Fragment, + table: pa.Table, + columns: tp.Optional[tp.List[str]] = None, + ) -> pa.Table: + """ + When loading a single fragment, pyarrow does not add the partitioning columns, + so we need to do it manually. + """ + for key, val in get_partition_keys(frag.partition_expression).items(): + if columns is None or key in columns: + values = pa.DictionaryArray.from_arrays( + np.zeros(len(table), dtype=np.int32), [val] + ) + table = table.append_column(key, values) + + return table + + def mini_batches( + self, max_chunk_size: int, columns: tp.Optional[tp.List[str]] = None + ) -> tp.Iterator[pa.Table]: + frag_columns = self.fragment_columns(columns) + for frag in self.fragment: + for record in frag.to_batches( + batch_size=max_chunk_size, columns=frag_columns + ): + table = pa.Table.from_batches([record]) + table = self._add_partitioning_columns(frag, table, columns) + yield table + del table + + def to_batches( + self, + batch_size: tp.Optional[int], + columns: tp.Optional[tp.List[str]] = None, + batch_format: BatchFormat = BatchFormat.PANDAS, + ) -> tp.Iterator[BatchType]: + # TODO : if filter does not contain partitioned columns we can propagate to fragment.to_table(...) 
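+        # A missing batch_size falls back to this shard's total row count, so the
+        # whole shard is read as a single batch instead of tripping the base-class
+        # assertion.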
+ yield from super().to_batches(batch_size or self.nb_rows, columns, batch_format) + + +FragmentSorting = Enum( + "FragmentSorting", ["DEFAULT", "RANDOM", "INCREASING_LENGTH", "DESCREASING_LENGTH"] +) + + +@functools.lru_cache(maxsize=10) +def get_dataset(input_dataset, filters_expr, filesystem) -> pq.ParquetDataset: + filter_: tp.Optional[pa.compute.Expression] = ( + eval(filters_expr, {}, {"pa": pa, "ds": ds, "pc": pc}) if filters_expr else None + ) + + return pq.ParquetDataset( + (input_dataset[0] if len(input_dataset) == 1 else list(input_dataset)), + filters=filter_, + filesystem=filesystem, + ) + + +@dataclass +class ParquetShardingConfig(InputShardingConfig): + """ + Config defining how one parquet shard outputs its content. + + Extra args: + - split_row_groups: whether to split the parquet fragments further by the row group + - sorting_strategy: fragment sorting strategy + - filesystem_expr: filesystem expression, permissible values: + * None -> to guess the filesystem from input_file ("s3://bucket/key") format (recommended) + * `s3fs` (using s3fs.core.S3FileSystem) + * `pyarrow_s3fs` (using pyarrow.fs.S3FileSystemt) + * Evaluable Python code (e.g `fs.S3FileSystem(region="us-west-2", role_arn=...)`) + + + Note that for manually provided `filesystem_expr`, one should use "bucket/key" as input file (without "s3://") !! + """ + + split_row_groups: bool = False + fragment_group_size: int = 1 + """ + This determines how many parquet fragments will be grouped to form a single shard. + Defaults to 1 + """ + nb_samples_per_group: tp.Optional[int] = None + """ + Allows to group several parquet fragments together so that the resulting shard will get ~ `nb_samples_per_group`. + Only partition filters are taking into account. + Defaults to None (not applied) + """ + + # aggregate_files: tp.Optional[tp.List[str]] = None + sorting_strategy: FragmentSorting = FragmentSorting.DEFAULT + filesystem_expr: tp.Optional[str] = None + + def validate(self) -> None: + if self.nb_samples_per_group is not None: + assert self.nb_samples_per_group > 0, "only positive values are accepted" + assert ( + self.fragment_group_size == 1 + ), "cannot use `fragment_group_size` with `group_to_nb_samples`" + _ = self.input_dataset # to init input dataset + + @functools.cached_property + def input_dataset(self) -> tp.Sequence[Path]: + if not isinstance(self.input_file, (list, tuple)): + # we expect single Path here + self.input_file = [str(self.input_file)] + + self.input_file, self.filesystem = get_filesystem_from_path( + self.input_file, filter=self.filesystem_expr + ) + self.input_file = tuple(self.input_file) + if self.filesystem is not None: + if hasattr(self.filesystem, "glob"): + return tuple( + [ + y + for f in self.input_file + for y in sorted(self.filesystem.glob(str(f))) + ] + ) + else: + return tuple(sorted(self.input_file)) + else: + return tuple(super().input_dataset) + + @functools.cached_property + def partition_columns(self) -> tp.List[str]: + dataset = get_dataset(self.input_dataset, self.filters_expr, self.filesystem) + partitioning = dataset.partitioning + if partitioning is None: + return [] + return [ + name + for name, dd in zip(partitioning.schema.names, partitioning.dictionaries) + if dd is not None + ] + + def make_shards(self, **kwargs) -> tp.List[Shard]: + dataset = get_dataset(self.input_dataset, self.filters_expr, self.filesystem) + fragments = list(dataset._dataset.get_fragments(self.filter))[: self.take] + + # TODO: making thread parallel when using S3 datasets + if self.split_row_groups: + 
fragments = [ + y for fragment in fragments for y in fragment.split_by_row_group() + ][: self.take] + + logger.info(f"Finding {len(fragments)} fragments") + + sorting_strategy = self.sorting_strategy + if sorting_strategy == FragmentSorting.RANDOM: + fragments = list(np.random.RandomState(None).permutation(fragments)) + + if self.nb_samples_per_group or sorting_strategy in [ + FragmentSorting.DESCREASING_LENGTH, + FragmentSorting.INCREASING_LENGTH, + ]: + nb_rows_per_fragment_ = [] + logger.info("Computing fragments rows count!") + + with logging_redirect_tqdm(): + for frag in tqdm(fragments): + nb_rows_per_fragment_.append(frag.count_rows()) + nb_rows_per_fragment = np.array(nb_rows_per_fragment_) + + if sorting_strategy in [ + FragmentSorting.DESCREASING_LENGTH, + FragmentSorting.INCREASING_LENGTH, + ]: + if sorting_strategy == FragmentSorting.DESCREASING_LENGTH: + permutation = np.argsort(-nb_rows_per_fragment, kind="stable") + else: + permutation = np.argsort(nb_rows_per_fragment, kind="stable") + nb_rows_per_fragment = nb_rows_per_fragment[permutation] + fragments = [fragments[i] for i in permutation] + + if self.nb_samples_per_group: + shards_list: tp.List[tp.List[pa.dataset.Fragment]] = [] + current_nb_samples = 0 + current_list = [] + for size, frag in zip( + nb_rows_per_fragment[: self.take], fragments[: self.take] + ): + current_list.append(frag) + current_nb_samples += size + if current_nb_samples >= self.nb_samples_per_group: + shards_list.append(current_list) + current_list = [] + current_nb_samples = 0 + + if current_list: # remainder + shards_list.append(current_list) + return [ + ParquetShard( + fragment=frags, + filter=self.filter, + skip_n_rows=self.skip_n_rows_per_shard.get(i, 0), + ) + for i, frags in enumerate(shards_list) + ] + + return [ + ParquetShard( + fragment=list(frags), + filter=self.filter, + skip_n_rows=self.skip_n_rows_per_shard.get(i, 0), + ) + for i, frags in enumerate( + batched(fragments[: self.take], self.fragment_group_size) + ) + ] + + +@dataclass +class ParquetOutputConfig(OutputDatasetConfig): + """ + For s3 files, there're two possible options : + * dataset_path = "s3://bucket/key/" and filesystem_expr = None (automatically getting client) + * dataset_path = "bucket/key/" and filesystem_expr = "s3fs" + + """ + + row_group_size: tp.Optional[int] = None + max_rows_per_file: tp.Optional[int] = None + keep_same_partitioning: bool = True + partition_columns: tp.Optional[tp.List[str]] = None + filesystem_expr: tp.Optional[str] = None + blobstore_expiration_timestamp: Optional[int] = None + + def __post_init__(self): + super().__post_init__() + if self.keep_same_partitioning and self.partition_columns is not None: + raise ValueError( + "cannot provide `partition_cols` when `keep_same_partining` is True" + ) + if self.compression == "default": + self.compression = "snappy" + assert self.compression in [ + None, + "none", + "snappy", + "gzip", + "brotli", + "lz4", + "zstd", + ] + self.dataset_path, self.filesystem = get_filesystem_from_path( # type: ignore + str(self.dataset_path), + filter=self.filesystem_expr, + expiration_timestamp=self.blobstore_expiration_timestamp, + ) + if self.filesystem is None: + Path(self.dataset_path).mkdir(parents=True, exist_ok=True) + else: + try: + self.filesystem.create_dir(self.dataset_path, recursive=True) + except Exception: # noqa + try: + self.filesystem.mkdir(self.dataset_path, create_parents=True) + except Exception: + pass + + def write_batch( + self, + batch: tp.Optional[BatchType], + iteration_index: 
tp.Sequence[int], + metadata: tp.Optional[tp.Dict[str, tp.Any]] = None, + state_checkpoint: tp.Optional[PartitionedDataMapperState] = None, + ) -> tp.List[Path]: + if batch is None or batch_length(batch) == 0: + # TODO: logger empty batch + return [] + table = batch_to_table(batch) + partition_cols: tp.Optional[tp.List[str]] = self.partition_columns + + try: + guid = hash_table_with_schema(table)[:20] + except Exception as e: + logger.warn(f"`hash_table_with_schema` failed : {e}") + guid = f"{uuid.uuid4()}"[:20] + + basename_template = "{i}_" + f"{guid}" + iteration_index = ( + (iteration_index,) if isinstance(iteration_index, int) else iteration_index + ) + for idx in iteration_index: + basename_template += f"_{idx}" + basename_template += ".parquet" + # Thus final template will look as follows: + # `{file_number}_{guid}_{batch_number}_{shard_number}.parquet` + + if metadata is not None: + try: + table = add_metadata_to_table(table, metadata) + except Exception as e: + logger.warn(f"`add_metadata_to_table` failed : {e}") + + written_files = [] + + def collect_files(f: pa._dataset.WrittenFile): + written_files.append(Path(f.path)) + + pq.write_to_dataset( + table, + self.dataset_path, + partition_cols=partition_cols, + max_rows_per_file=self.max_rows_per_file, + filesystem=self.filesystem, + schema=self.expected_schema if self.validate_schema else None, + basename_template=basename_template, + use_threads=True, + file_visitor=collect_files, + **{ + "row_group_size": self.row_group_size or self.max_rows_per_file, + "compression": self.compression, + }, + ) + + if state_checkpoint: + with self._open(state_checkpoint.iteration_value, "wb") as f: + cloudpickle.dump(state_checkpoint, f) + + return sorted(written_files) + + def reload_state( + self, + shard: Shard, + ) -> tp.Optional[PartitionedDataMapperState]: + try: + with self._open(shard, "rb") as f: # filename is wrong + return cloudpickle.load(f) + except: + return None + + @contextmanager + def _open(self, shard: Shard, mode: str = "r"): + shard_hash = xxhash.xxh3_64_intdigest(cloudpickle.dumps(shard)) + fname = Path(self.dataset_path) / f".parquet_output.{shard_hash}.state" + if self.filesystem is None: + with fname.open(mode) as f: + yield f + else: + with self.filesystem.open(str(fname), mode) as f: + yield f diff --git a/stopes/utils/sharding/text_shards.py b/stopes/utils/sharding/text_shards.py new file mode 100644 index 0000000..5dc7175 --- /dev/null +++ b/stopes/utils/sharding/text_shards.py @@ -0,0 +1,834 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
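Putting the parquet pieces above together, a minimal end-to-end sketch; the paths, the "lang"/"text" columns and all parameter values are placeholders:

from stopes.utils.sharding.abstract_shards import BatchFormat
from stopes.utils.sharding.parquet_shards import (
    ParquetOutputConfig,
    ParquetShardingConfig,
)

config = ParquetShardingConfig(
    input_file="data/train/*.parquet",
    filters_expr='pc.field("lang") == "eng"',
    fragment_group_size=2,
)
out = ParquetOutputConfig(dataset_path="data/processed", row_group_size=4_096)

for i, shard in enumerate(config.make_shards()):
    with shard:
        for batch in shard.to_batches(
            batch_size=10_000,
            columns=["text", "lang"],
            batch_format=BatchFormat.ARROW,
        ):
            out.write_batch(batch, iteration_index=(i,))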
+# +# +# Different methods to support sharding for audio files + +import functools +import io +import itertools +import logging +import shutil +import typing as tp +import uuid +from dataclasses import dataclass +from glob import glob +from pathlib import Path + +import cloudpickle +import numpy as np +import pandas as pd +import pyarrow as pa +import xxhash +from omegaconf.listconfig import ListConfig +from pyarrow import csv as csv_pa + +from stopes.core.utils import expand_if_compressed +from stopes.core.utils import open as stopes_open +from stopes.core.utils import sort_file +from stopes.utils.arrow_utils import hash_table_with_schema +from stopes.utils.file_chunker_utils import find_offsets_of_lines +from stopes.utils.sharding.abstract_shards import ( + BatchFormat, + BatchType, + InputShardingConfig, + OutputDatasetConfig, + PartitionedDataMapperState, + Shard, + arrow_table_to_batch, + batch_length, + batch_to_pandas, + batch_to_table, +) + + +@functools.lru_cache(10) +def warn_once(msg: str) -> None: + """Prevents flooding stderr with the same repeated error message.""" + log.warning(msg) + + +log = logging.getLogger("stopes.speech.shards") + + +@dataclass +class TextShard(Shard): + """ + input to one worker processing a file shard for an array module. Default behaviour + is that the worker will process an entire file. + + A shard is a contextmanager object: When you enter a shard in a local job, it gives + access from the input file resource (by default via `stopes.core.utils.open()`). + + A shard is also an iterator: You lazily reads each line after entering the shard. It + will update the internal states silently, to ensure the reading can be recovered if + the job needs to be re-run. + Note that this recovery is only guaranteed within one (slurm) job or machine, and not + if the whole pipeline is re-run, because a Shard object - once created - will be sent + and kept locally to each job only. + + Args: + input_file (Path): The input file. + columns (list or bool, optional): a list of header columns. None if there is no header + sep (optional): the separator of lines. Only applicable when `cols` is not None + index: index of the shard. None if there is only one shard for the file + path_column : when not None, means the column's name (returned with to_batches()) + containing the file path from which the corresponding data is read. + If None, no extra column is added. + + """ + + input_file: tp.Union[str, Path] + columns: tp.Optional[tp.List[str]] = None + sep: tp.Optional[str] = None + index: tp.Optional[int] = None + path_column: tp.Optional[str] = None + + def __post_init__(self): + """Prepare internal properties""" + super().__post_init__() + # Keep how many lines already processed. 
Use to re-run the job + self._lines_cnt: int = 0 + + # handle the input resource + self._input_handler: tp.Optional[tp.ContextManager] = None + self._reader: tp.Optional[tp.Iterator[str]] = None + + def __enter__(self) -> "TextShard": + if not Path(self.input_file).exists(): + raise FileNotFoundError(self.input_file) + self._reader = self.input_handler.__enter__() + return self + + @property + def input_handler(self) -> tp.ContextManager: + if self._input_handler is None: + self._input_handler = stopes_open(self.input_file) + return self._input_handler + + def resolve_column_index(self, column_name: tp.Union[int, str]) -> int: + if isinstance(column_name, int) or column_name.isdecimal(): + return int(column_name) + assert ( + isinstance(self.columns, tp.List) and len(self.columns) > 0 + ), f"{self.input_file} has no header" + try: + return self.columns.index(column_name) + except ValueError: + raise ValueError( + f"Column {column_name} not found in header of {self.input_file}: {self.columns}" + ) + + def value(self, column_name: tp.Union[int, str]) -> str: + """Get value from a given column in the current line""" + + column_offset = self.resolve_column_index(column_name) + lines = self.line.rstrip().split(self.sep) + return lines[column_offset] + + def __iter__(self) -> tp.Iterator[str]: + """start or resume the input file consumption from the last attempt.""" + + lines = iter(self._reader) # type: ignore + if self.has_started(): + log.info( + f"Resuming from previous attempt, already processed {self._lines_cnt} lines" + ) + # Skip processed lines + skipped_lines = int(self.contains_header()) + self._lines_cnt + for line in itertools.islice(lines, skipped_lines, None): + # Keep track of current line and processed lines so far + self.line = line + self._lines_cnt += 1 + yield line + + def has_started(self): + """whether the shard is already (partially) processed""" + return self._lines_cnt > 0 + + def contains_header(self) -> bool: + """whether the corresponding shard contains header""" + return bool(self.columns) and not bool(self.index) + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.input_handler.__exit__(exc_type, exc_val, exc_tb) + self._input_handler = None + self._reader = None + + def to_batches( + self, + batch_size: tp.Optional[int], + columns: tp.Optional[tp.List[str]] = None, + batch_format: BatchFormat = BatchFormat.PANDAS, + ) -> tp.Iterator[BatchType]: + assert batch_size is None or batch_size > 0 + + if columns is not None and self.columns is not None: + assert set(columns).issubset(set(self.columns)) + + if columns is None: + columns = self.columns + + read_options = csv_pa.ReadOptions( + use_threads=True, column_names=self.columns, encoding="utf8" + ) + parse_options = csv_pa.ParseOptions(delimiter=self.sep, ignore_empty_lines=True) + convert_options = csv_pa.ConvertOptions(include_columns=columns) + + old_lines_cnt, self._lines_cnt = self._lines_cnt, 0 + with self as reading_context: + stream = io.BytesIO("".join(reading_context).encode()) + table = csv_pa.read_csv( + stream, + read_options=read_options, + parse_options=parse_options, + convert_options=convert_options, + ) + stream.close() + self._lines_cnt = old_lines_cnt + + if self.path_column: + table = table.append_column( + self.path_column, + pa.DictionaryArray.from_arrays( + np.zeros(len(table), dtype=np.int32), [str(self.input_file)] + ), + ) + + if self.filter is not None: + table = table.filter(self.filter) + if len(table) > 0: + if batch_size is None: + yield arrow_table_to_batch(table, 
batch_format) + else: + for tt in table.to_batches(max_chunksize=batch_size): + min_table = pa.Table.from_batches([tt]) + yield arrow_table_to_batch(min_table, batch_format) + + +@dataclass +class TextShardingConfig(InputShardingConfig): + nb_shards: int = 1 + sharding_strategy: str = "chunk" + header: tp.Any = True # should be tp.Optional[tp.Union[bool, tp.List[str]]] + sep: tp.Optional[str] = None + cache_dir: tp.Optional[Path] = None + path_column: tp.Optional[str] = None + # TODO: restrict only on supported sharding strategies + # TODO: split by given number of rows + """ + - header: either bool (True meaning the presence of header in a file) or explicit list of resulting column names + - path_column : when not None, means the column's name (returned with shard.to_batches()) + containing the file path from which the corresponding data is read. If None, no extra column is added. + """ + + def __post_init__(self): + super().__post_init__() + assert self.nb_shards > 0, f"invalid number of shards ({self.nb_shards})" + assert ( + len(self.skip_n_rows_per_shard) == 0 + ), "skipping not supported for this shard type" + + def validate(self) -> None: + # TODO: verify that files exists and are readable with provided parameters + pass + + def make_shards(self, **kwargs) -> tp.List[Shard]: + shards: tp.List[Shard] = list( + make_text_file_shards( + input=( + self.input_dataset[0] + if len(self.input_dataset) == 1 + else self.input_dataset + ), + nshards=self.nb_shards, + algo=self.sharding_strategy, + header=self.header, # type: ignore + sep=self.sep, + cache_dir=self.cache_dir, + filter=self.filter, + **kwargs, + ) + ) + shards = shards[: self.take] + if self.path_column: + for shard in shards: + shard.path_column = self.path_column # type: ignore + return shards + + +@dataclass +class TextOutputConfig(OutputDatasetConfig): + header: bool = True + sep: str = "\t" + storage_options: tp.Optional[tp.Dict[str, str]] = None + quoting: tp.Optional[int] = None + + """ + """ + + def __post_init__(self) -> None: + super().__post_init__() + + assert self.sep in [ + ",", + "\t", + ], f"only comma and tab are supported as separators, got {self.sep}" + if self.compression == "default": + self.compression = None + if self.validate_schema: + raise NotImplementedError("not supported yet for text files") + + assert self.compression in [ + None, + "zip", + "gzip", + "bz2", + "zstd", + "xz", + "tar", + ], f"unsupported compression {self.compression}" + Path(self.dataset_path).mkdir(parents=True, exist_ok=True) + + @staticmethod + def compression_to_extension(compression: tp.Optional[str]) -> str: + if compression is None: + return "" + return f".{compression}" + + @staticmethod + def separator_to_extension(sep: str) -> str: + mapping = {",": ".csv", "\t": ".tsv"} + return mapping.get(sep, ".txt") + + def write_batch( + self, + batch: BatchType, + iteration_index: tp.Sequence[int], + metadata: tp.Optional[tp.Dict[str, tp.Any]] = None, + state_checkpoint: tp.Optional[PartitionedDataMapperState] = None, + ) -> tp.List[Path]: + if batch is None or batch_length(batch) == 0: + # TODO: logger empty batch + return [] + + # TODO: reuse resolve_output logic here + try: + guid = hash_table_with_schema(batch_to_table(batch))[:20] + except Exception as e: + print(f"`hash_table_with_schema` failed : {e}") + guid = f"{uuid.uuid4()}"[:20] + + file_name = f"{guid}" + iteration_index = ( + (iteration_index,) if isinstance(iteration_index, int) else iteration_index + ) + for idx in iteration_index: + file_name += f"_{idx}" + 
file_name += f"{self.separator_to_extension(self.sep)}{self.compression_to_extension(self.compression)}" + + path = Path(self.dataset_path).joinpath(file_name) + + df_pd: pd.DataFrame = batch_to_pandas(batch) + + df_pd.to_csv( + path, + sep=self.sep, + header=self.header, + quoting=self.quoting, + compression=self.compression, + storage_options=self.storage_options, + index=False, + ) + + if state_checkpoint: + shard_hash = xxhash.xxh3_64_intdigest( + cloudpickle.dumps(state_checkpoint.iteration_value) + ) + with (Path(self.dataset_path) / f".text_output.{shard_hash}.state").open( + "wb" + ) as f: # filename is wrong + cloudpickle.dump(state_checkpoint, f) + + # this could be interesing + # https://arrow.apache.org/docs/python/csv.html#incremental-writing + # it'll be about x3 - x4 faster for writing but we need to handle the compression and remote storage adhoc + # TODO : Write metadata + return [path] + + def reload_state( + self, + shard: Shard, + ) -> tp.Optional[PartitionedDataMapperState]: + try: + shard_hash = xxhash.xxh3_64_intdigest(cloudpickle.dumps(shard)) + with (Path(self.dataset_path) / f".text_output.{shard_hash}.state").open( + "rb" + ) as f: # filename is wrong + return cloudpickle.load(f) + except: + return None + + +@dataclass +class TopNShard(TextShard): + """ + progress of one worker processing a file up to top-N lines + """ + + nrows: tp.Optional[int] = None + + def __iter__(self) -> tp.Iterator[str]: + lines = super().__iter__() + lines = itertools.islice(lines, 0, self.nrows) + for line in lines: + yield line + + +@dataclass +class ChunkShard(TextShard): + """ + A shard that corresponds to a file contiguous chunk. + + Args: + start (int): start byte offset of the shard + end (int): end byte offset of the shard. None if the shard is to be processed till EOF + """ + + from stopes.utils.file_chunker_utils import Chunker + + start: int = 0 + end: tp.Optional[int] = None + + @property + def input_handler(self) -> tp.ContextManager: + if self._input_handler is None: + self._input_handler = self.Chunker( + str(self.input_file), self.start, self.end + ) + return self._input_handler + + +@dataclass +class RoundRobinShard(TextShard): + """ + A shard that corresponds to a subset of lines read from the file in the round robin fashion + + Args: + nshards: Number of the shards + """ + + nshards: int = 1 + + def __iter__(self) -> tp.Iterator[str]: + if self.has_started(): + log.info( + f"Resuming from previous attempt, already processed {self._lines_cnt} lines" + ) + skipped_lines = int(self.contains_header()) + self._lines_cnt + for i, line in enumerate(iter(self._reader)): # type: ignore + if i % self.nshards == self.index: + if skipped_lines == 0: + self.line = line + self._lines_cnt += 1 + yield line + else: + skipped_lines -= 1 + + +def resolve_output( + shard: TextShard, output_file: tp.Optional[Path] = None, suffix: str = "" +) -> tp.Optional[Path]: + """ + A convenience function to help users make a standard output filename for the shard. + Recommended if user wants a more consistent sharding output naming to be used in stopes pipeline + + The output suffix calibration logic: + + First, find the proper suffix (in order of priority): + + - If the input `suffix` is given, use it + - If the `output_file` if a file with suffixes, use it + - If the `output_file` is a directory, use input_file suffix. 
+ - If neither `output_file` nor `suffix` is given, use the input_file suffix + + In all cases, make sure the output is not compressed (even if input_file is compressed) + except the user explicitly wants it, either via `output_file` or `suffix` + + After that, prepend shard index to the output suffix + + Example: + + - ouput_file = out.txt , suffix = ".tsv.gz" , no shard --> output = out.tsv.gz + - ouput_file = out.txt , suffix = ".tsv.gz" , 2 shards --> outputs = out.0.tsv.gz , out.1.tsv.gz + - ouput_file = out.txt , suffix = "" , no shard --> output = out.txt + - ouput_file = out.tsv.gz , suffix = "" , 2 shards --> outputs = out.0.tsv.gz , out.1.tsv.gz + - ouput_file = out_dir , suffix = ".tsv.gz" , no shard --> output = out_dir / input.tsv.gz + - output_file = None, suffix = "", input = in.tsv.gz", no shard --> output = in.tsv + - output_file = None, suffix = "", input = in.tsv.gz", 2 shards --> output = in.0.tsv, in.1.tsv + - output_file = file_without_ext, suffix = "" , input = in.tsv.gz, 2 shards -> ouput = file_without_ext.0, file_without_ext.1 + + """ + # stoptes.utils.file_chunker_utils adds "_expanded.txt" to a compressed file + input_name = Path(shard.input_file).name.replace("_expanded.txt", "") + + # an intermediate file from stopes.utils.sort_files has a ".merge_sort" suffix + input_name = input_name.replace(".merge_sort", "") + + # Unless user specifies suffix with .gz or .xz, we do not compress output + input_name = input_name.replace(".gz", "").replace(".xz", "") + + in_suffix = Path(input_name).suffix # .tsv or .txt + input_stem = Path(input_name).stem + + if suffix: + out_suffix = suffix + elif output_file is None or output_file.is_dir(): + out_suffix = in_suffix + else: + out_suffix = "".join(output_file.suffixes) + + # If there are more than one shard for the file, add shard index to each output name + if shard.index is not None: + out_suffix = f".{shard.index}{out_suffix}" + + if output_file is None: + resolved_output = (Path(shard.input_file).parent / input_stem).with_suffix( + out_suffix + ) + elif output_file.is_dir(): + resolved_output = (output_file / input_stem).with_suffix(out_suffix) + elif len(output_file.suffixes) == 0: + resolved_output = output_file.with_suffix(out_suffix) + else: + # file.ext1.ext2.ext3 --> file + output_stem = output_file.name[: -len("".join(output_file.suffixes))] + resolved_output = output_file.parent / (output_stem + out_suffix) + + # Happens when suffix = "" and output_file = None and input_file is not compressed + if resolved_output.resolve() == Path(shard.input_file).resolve(): + log.warning( + f"Output file is the same as input file ({shard.input_file}). Writing is disabled" + ) + return None + return resolved_output.resolve() + + +def parse_header( + input_file: Path, + header: tp.Optional[tp.Union[bool, tp.List[str]]], + sep: tp.Optional[str], +) -> tp.Optional[tp.List[str]]: + if header is False or header is None: + return None + + assert ( + sep is not None + ), "Please provide separator input. 
Sharder will not guess the file format cannot be guessed at this point" + + with stopes_open(input_file) as reader: + parsed_cols = next(reader).rstrip("\n").split(sep) + + if header is True: + return parsed_cols + if isinstance(header, (list, tuple)): + return list(header) + + +def make_one_text_shard( + input_file: Path, + header: tp.Optional[tp.Union[bool, tp.List[str]]] = False, + sep: tp.Optional[str] = None, + nrows: tp.Optional[int] = None, + filter: tp.Optional[pa.dataset.Expression] = None, +) -> tp.Sequence[TextShard]: + cols = parse_header(input_file, header, sep) + if nrows is not None: + return [ + TopNShard( + input_file=input_file, + columns=cols, + index=0, + sep=sep, + nrows=nrows, + filter=filter, + ) + ] + else: + return [ + TextShard( + input_file=input_file, columns=cols, index=0, sep=sep, filter=filter + ) + ] + + +def make_chunk_text_shards( + input_file: Path, + nshards: int = 1, + header: tp.Optional[tp.Union[bool, tp.List[str]]] = False, + sep: tp.Optional[str] = None, + cache_dir: tp.Optional[Path] = None, + nrows: tp.Optional[int] = None, + filter: tp.Optional[pa.dataset.Expression] = None, +) -> tp.Sequence[ChunkShard]: + """ + Create multiple shards from a single file, where each share corresponds to a continuous + chunk of lines in the file. If the file is compressed, it will be decompressed into + the `cache_dir` under the new name: `input_file_expanded.txt` + """ + + from stopes.utils.file_chunker_utils import find_offsets + + if input_file.suffix in {".gz", ".xz"}: + warn_once( + "Multiple shards for compressed file is asked. Chunking the compressed file results " + "in a slow scheduler warm-up. Please give the decompressed file if possible." + ) + _cache = cache_dir if cache_dir else input_file.parent + assert Path( + _cache + ).exists(), ( + f"cache directory {_cache} not found, cannot write intermediate files" + ) + input_file = expand_if_compressed(input_file, _cache) # type: ignore + + if nrows: + offsets = find_offsets_of_lines(str(input_file), nshards, nrows) + else: + offsets = find_offsets(str(input_file), nshards) + + # Convert [pos1, pos2, pos3,...] to [(pos1, pos2), (pos2, pos3),..] + file_chunks = zip(offsets, offsets[1:]) + + return [ + ChunkShard( + input_file=input_file, + columns=parse_header(input_file, header, sep), + index=i, + sep=sep, + start=start, + end=end, + filter=filter, + ) + for i, (start, end) in enumerate(file_chunks) + ] + + +def make_roundrobin_text_shards( + input_file: Path, + nshards: int = 1, + header: tp.Optional[tp.Union[bool, tp.List[str]]] = False, + sep: tp.Optional[str] = None, + filter: tp.Optional[pa.dataset.Expression] = None, +) -> tp.Sequence[RoundRobinShard]: + """ + Make multiple shards from a single file where each shard correspond to all lines in the file + with certain index. For example, if there are 8 shards, shard_0 corresponds to lines 0, 8, 16,.. 
+ """ + return [ + RoundRobinShard( + input_file=input_file, + columns=parse_header(input_file, header, sep), + index=i, + sep=sep, + nshards=nshards, + filter=filter, + ) + for i in range(nshards) + ] + + +def make_sorted_text_shards( + input_file: Path, + nshards: int = 1, + header: tp.Optional[tp.Union[bool, tp.List[str]]] = False, + sep: tp.Optional[str] = None, + cache_dir: tp.Optional[Path] = None, + col: tp.Optional[tp.Union[str, int]] = None, + no_duplicate: bool = False, + filter: tp.Optional[pa.dataset.Expression] = None, +) -> tp.Sequence[TextShard]: + """ + Create shards from one file by sorting the lines after values in column `col` and divide the + sorted lines into different chunks. This algorithm requires input file to be uncompressed + before. If `no_duplicate` is True, the shard will not have duplicates + """ + + from stopes.utils.file_chunker_utils import find_offsets + + if input_file.suffix in {".gz", ".xz"}: + warn_once( + "Multiple shards for compressed file is asked. Chunking the compressed file results " + "in a slow scheduler warm-up. Please give the decompressed file if possible." + ) + _cache = cache_dir if cache_dir else input_file.parent + assert Path( + _cache + ).exists(), ( + f"cache directory {_cache} not found, cannot write intermediate files" + ) + input_file = expand_if_compressed(input_file, _cache) # type: ignore + + sorted_file = str(input_file) + ".merge_sort" + sort_file(input_file, sorted_file, col=col, sep=sep, no_duplicate=no_duplicate) + offsets = find_offsets(str(sorted_file), nshards) + + # Convert [pos1, pos2, pos3,...] to [(pos1, pos2), (pos2, pos3),..] + file_chunks = zip(offsets, offsets[1:]) + return [ + ChunkShard( + input_file=sorted_file, + sep=sep, + columns=parse_header(input_file, header, sep), + index=i, + start=start, + end=end, + filter=filter, + ) + for i, (start, end) in enumerate(file_chunks) + ] + + +def make_text_file_shards( + input: tp.Union[str, tp.Sequence, Path], + nshards: int = 1, + algo: str = "chunk", + header: tp.Optional[tp.Union[bool, tp.List[str]]] = False, + sep: tp.Optional[str] = None, + cache_dir: tp.Optional[Path] = None, + filter: tp.Optional[pa.dataset.Expression] = None, + **kwargs, +) -> tp.Sequence[TextShard]: + """ + Make shards from an `input`. + + Args: + input (str, list or Path): Input to gnerate the shards. It could be: + nshards: The number of shards to generate from the input. Only applicable if `input` + is a single file. + header (bool): Whether or not the input files have headers. 
Default False + sep: separator for columns in the line (all files in `input` must have the same format and separator) + cache_dir (str, Path, optional): directory to cache the intermedia files (such as uncompressed input file) + filter : to apply when batching # TODO: apply them for simple row iterations for consistency + + """ + assert nshards > 0, f"invalid number of shards ({nshards})" + if isinstance(input, tp.List) or isinstance(input, ListConfig): + return [ + s + for f in input + for s in make_text_file_shards( + input=f, + nshards=1, + algo=algo, + header=header, + sep=sep, + cache_dir=cache_dir, + filter=filter, + **kwargs, + ) + ] + elif (p := Path(input)).is_dir(): # type: ignore + return [ + s + for f in p.iterdir() + for s in make_text_file_shards( + input=f, + nshards=1, + algo=algo, + header=header, + sep=sep, + cache_dir=cache_dir, + filter=filter, + **kwargs, + ) + ] + elif not p.is_file(): + return [ + s + for f in sorted(glob(str(input))) + for s in make_text_file_shards( + input=f, + nshards=1, + algo=algo, + header=header, + sep=sep, + cache_dir=cache_dir, + filter=filter, + **kwargs, + ) + ] + elif nshards == 1: + return make_one_text_shard( + p, + header, + sep, + filter=filter, + ) + elif algo == "chunk": + return make_chunk_text_shards( + p, + nshards, + header, + sep, + cache_dir, + kwargs.get("nrows"), + filter=filter, + ) + elif algo == "robin": + return make_roundrobin_text_shards( + p, + nshards, + header, + sep, + filter=filter, + ) + elif algo == "sort": + return make_sorted_text_shards( + input_file=p, + nshards=nshards, + header=header, + sep=sep, + cache_dir=cache_dir, + col=kwargs.get("col"), + no_duplicate=bool(kwargs.get("no_duplicate")), + filter=filter, + ) + + raise ValueError( + f"invalid input: input={str(input)}, nshards={nshards}, algo={algo}" + ) + + +def merge_shards(shards: tp.List[tp.Union[Path]], outdir: Path, suffix: str = ""): + """Merge the shard outputs in the order of the shard indices""" + + def get_name_no_ext(fname: str) -> str: + if len(suffix) == 0: + return fname + return fname[: -len(suffix)] + + def get_shard_idx(fname: Path) -> int: + fname_no_ext = get_name_no_ext(str(fname)) + shard_idx = fname_no_ext[fname_no_ext.rfind(".") + 1 :] + return int(shard_idx) + + if len(shards) == 1: + shutil.copy2(shards[0], outdir) + else: + fname = get_name_no_ext(shards[0].name) + outputfile = str(outdir) + "/" + (fname[: fname.rfind(".")] + suffix) + ordered_shards = sorted(shards, key=get_shard_idx) + log.info(f"Writing {len(ordered_shards)} shard outputs to {outputfile}") + + with stopes_open(outputfile, "wt") as o: + for shard_output in ordered_shards: + try: + with stopes_open(shard_output) as so: + for line in so: + o.write(line) + except Exception: + log.error(f"Error in processing {shard_output}") diff --git a/stopes/utils/shards.py b/stopes/utils/shards.py deleted file mode 100644 index 5ad2f9a..0000000 --- a/stopes/utils/shards.py +++ /dev/null @@ -1,508 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
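The helpers above compose into a simple split-process-merge flow: shard one input file, process each shard independently, then stitch the per-shard outputs back together in index order. The sketch below is illustrative only: it assumes the new functions are importable from `stopes.utils.sharding.text_shards` (the module path is not visible in this hunk), uses a toy two-column TSV, and the upper-casing step stands in for real per-shard work.

from pathlib import Path

# Assumed import path: these helpers are new in this diff and their module path
# is not shown in this hunk; stopes.utils.sharding.text_shards is a guess.
from stopes.utils.sharding.text_shards import (
    make_text_file_shards,
    merge_shards,
    resolve_output,
)

workdir = Path("shard_demo")
outdir = workdir / "out"
outdir.mkdir(parents=True, exist_ok=True)

# Toy two-column TSV; the "text" column is purely illustrative.
input_tsv = workdir / "data.tsv"
input_tsv.write_text(
    "id\ttext\n" + "".join(f"{i}\thello {i}\n" for i in range(100)), encoding="utf-8"
)

# Split into 4 contiguous byte-range chunks; header=True parses the column names.
shards = make_text_file_shards(input_tsv, nshards=4, algo="chunk", header=True, sep="\t")

outputs = []
for shard in shards:
    # resolve_output injects the shard index: out/data.0.tsv, out/data.1.tsv, ...
    out_file = resolve_output(shard, output_file=outdir, suffix=".tsv")
    assert out_file is not None
    with shard as rows, open(out_file, "w", encoding="utf-8") as o:
        for _ in rows:
            o.write(rows.value("text").upper() + "\n")  # stand-in for real per-shard work
    outputs.append(out_file)

# Concatenate the per-shard outputs in index order into out/data.tsv
merge_shards(outputs, outdir, suffix=".tsv")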
-# -# -# Different methods to support sharding for audio files - -import functools -import glob -import itertools -import logging -import typing as tp -from dataclasses import dataclass -from pathlib import Path - -from typing_extensions import Self - -from stopes.core.utils import expand_if_compressed -from stopes.core.utils import open as stopes_open -from stopes.core.utils import sort_file - - -@functools.lru_cache(10) -def warn_once(msg: str) -> None: - """Prevents flooding stderr with the same repeated error message.""" - log.warning(msg) - - -log = logging.getLogger("stopes.speech.shards") - - -@dataclass -class Shard: - """ - input to one worker procesing a file shard for an array module. Default behaviour - is that the worker will process an entire file. - - A shard is a contextmanager object: When you enter a shard in a local job, it gives - access from the input file resource (by default via `stopes.core.utils.open()`). - - A shard is also an iterator: You lazily reads each line after entering the shard. It - will update the internal states silently, to ensure the reading can be recovered if - the job needs to be re-run. - Note that this recovery is only guaranteed within one (slurm) job or machine, and not - if the whole pipeline is re-run, because a Shard object - once created - will be sent - and kept locally to each job only. - - Args: - input_file (Path): The input file. - cols (list or bool, optional): a list of header columns. None if there is no header - sep (optional): the separator of lines. Only applicable when `cols` is not None - index: index of the shard. None if there is only one shard for the file - """ - - input_file: tp.Union[str, Path] - cols: tp.Optional[tp.List[str]] = None - sep: tp.Optional[str] = None - index: tp.Optional[int] = None - - def __post_init__(self): - """Prepare internal properties""" - - # Keep how many lines already processed. 
Use to re-run the job - self._lines_cnt: int = 0 - - # handle the input resource - self._input_handler: tp.Optional[tp.ContextManager] = None - self._reader: tp.Optional[tp.Iterator[str]] = None - - def __enter__(self) -> Self: - if not Path(self.input_file).exists(): - raise FileNotFoundError(self.input_file) - self._reader = self.input_handler.__enter__() - return self - - @property - def input_handler(self) -> tp.ContextManager: - if self._input_handler is None: - self._input_handler = stopes_open(self.input_file) - return self._input_handler - - def resolve_column_index(self, column_name: tp.Union[int, str]) -> int: - if isinstance(column_name, int) or column_name.isdecimal(): - return int(column_name) - assert ( - isinstance(self.cols, tp.List) and len(self.cols) > 0 - ), f"{self.input_file} has no header" - try: - return self.cols.index(column_name) - except ValueError: - raise ValueError( - f"Column {column_name} not found in header of {self.input_file}: {self.cols}" - ) - - def value(self, column_name: tp.Union[int, str]) -> str: - """Get value from a given column in the current line""" - - column_offset = self.resolve_column_index(column_name) - lines = self.line.rstrip().split(self.sep) - return lines[column_offset] - - def __iter__(self) -> tp.Iterator[str]: - """start or resume the input file consumption from the last attempt.""" - lines = iter(self._reader) # type: ignore - if self.has_started(): - log.info( - f"Resuming from previous attempt, already processed {self._lines_cnt} lines" - ) - # Skip processed lines - skipped_lines = int(self.contains_header()) + self._lines_cnt - for line in itertools.islice(lines, skipped_lines, None): - # Keep track of current line and processed lines so far - self.line = line - self._lines_cnt += 1 - yield line - - def has_started(self): - """whether the shard is already (partially) processed""" - return self._lines_cnt > 0 - - def contains_header(self) -> bool: - """whether the corresponding shard contains header""" - return bool(self.cols) and not bool(self.index) - - def __exit__(self, exc_type, exc_val, exc_tb) -> None: - self.input_handler.__exit__(exc_type, exc_val, exc_tb) - self._input_handler = None - self._reader = None - - -@dataclass -class TopNShard(Shard): - """ - progress of one worker processing a file up to top-N lines - """ - - nrows: tp.Optional[int] = None - - def __iter__(self) -> tp.Iterator[str]: - lines = super().__iter__() - lines = itertools.islice(lines, 0, self.nrows) - for line in lines: - yield line - - -@dataclass -class ChunkShard(Shard): - """ - A shard that corresponds to a file contiguous chunk. - - Args: - start (int): start byte offset of the shard - end (int): end byte offset of the shard. 
None if the shard is to be processed till EOF - """ - - from stopes.utils.file_chunker_utils import Chunker - - start: int = 0 - end: tp.Optional[int] = None - - @property - def input_handler(self) -> tp.ContextManager: - if self._input_handler is None: - self._input_handler = self.Chunker( - str(self.input_file), self.start, self.end - ) - return self._input_handler - - -@dataclass -class RoundRobinShard(Shard): - """ - A shard that corresponds to a subset of lines read from the file in the round robin fashion - - Args: - nshards: Number of the shards - """ - - nshards: int = 1 - - def __iter__(self) -> tp.Iterator[str]: - if self.has_started(): - log.info( - f"Resuming from previous attempt, already processed {self._lines_cnt} lines" - ) - skipped_lines = int(self.contains_header()) + self._lines_cnt - for i, line in enumerate(iter(self._reader)): # type: ignore - if i % self.nshards == self.index: - if skipped_lines == 0: - self.line = line - self._lines_cnt += 1 - yield line - else: - skipped_lines -= 1 - - -def resolve_output( - shard: Shard, output_file: tp.Optional[Path] = None, suffix: str = "" -) -> tp.Optional[Path]: - """ - A convenience function to help users make a standard output filename for the shard. - Recommended if user wants a more consistent sharding output naming to be used in stopes pipeline - - The output suffix calibration logic: - - First, find the proper suffix (in order of priority): - - - If the input `suffix` is given, use it - - If the `output_file` if a file with suffixes, use it - - If the `output_file` is a directory, use input_file suffix. - - If neither `output_file` nor `suffix` is given, use the input_file suffix - - In all cases, make sure the output is not compressed (even if input_file is compressed) - except the user explicitly wants it, either via `output_file` or `suffix` - - After that, prepend shard index to the output suffix - - Example: - - - ouput_file = out.txt , suffix = ".tsv.gz" , no shard --> output = out.tsv.gz - - ouput_file = out.txt , suffix = ".tsv.gz" , 2 shards --> outputs = out.0.tsv.gz , out.1.tsv.gz - - ouput_file = out.txt , suffix = "" , no shard --> output = out.txt - - ouput_file = out.tsv.gz , suffix = "" , 2 shards --> outputs = out.0.tsv.gz , out.1.tsv.gz - - ouput_file = out_dir , suffix = ".tsv.gz" , no shard --> output = out_dir / input.tsv.gz - - output_file = None, suffix = "", input = in.tsv.gz", no shard --> output = in.tsv - - output_file = None, suffix = "", input = in.tsv.gz", 2 shards --> output = in.0.tsv, in.1.tsv - - output_file = file_without_ext, suffix = "" , input = in.tsv.gz, 2 shards -> ouput = file_without_ext.0, file_without_ext.1 - - """ - # stoptes.utils.file_chunker_utils adds "_expanded.txt" to a compressed file - input_name = Path(shard.input_file).name.replace("_expanded.txt", "") - - # an intermediate file from stopes.utils.sort_files has a ".merge_sort" suffix - input_name = input_name.replace(".merge_sort", "") - - # Unless user specifies suffix with .gz or .xz, we do not compress output - input_name = input_name.replace(".gz", "").replace(".xz", "") - - in_suffix = Path(input_name).suffix # .tsv or .txt - input_stem = Path(input_name).stem - - if suffix: - out_suffix = suffix - elif output_file is None or output_file.is_dir(): - out_suffix = in_suffix - else: - out_suffix = "".join(output_file.suffixes) - - # If there are more than one shard for the file, add shard index to each output name - if shard.index is not None: - out_suffix = f".{shard.index}{out_suffix}" - - if output_file is 
None: - resolved_output = (Path(shard.input_file).parent / input_stem).with_suffix( - out_suffix - ) - elif output_file.is_dir(): - resolved_output = (output_file / input_stem).with_suffix(out_suffix) - elif len(output_file.suffixes) == 0: - resolved_output = output_file.with_suffix(out_suffix) - else: - # file.ext1.ext2.ext3 --> file - output_stem = output_file.name[: -len("".join(output_file.suffixes))] - resolved_output = output_file.parent / (output_stem + out_suffix) - - # Happens when suffix = "" and output_file = None and input_file is not compressed - if resolved_output.resolve() == Path(shard.input_file).resolve(): - log.warning( - f"Output file is the same as input file ({shard.input_file}). Writing is disabled" - ) - return None - return resolved_output.resolve() - - -def parse_header(input_file: Path, sep: str): - with stopes_open(input_file) as reader: - return next(reader).rstrip("\n").split(sep) - - -def find_offsets_of_lines(filename: str, num_chunks: int, nrows: int) -> tp.List[int]: - """ - Find the offsets of a text file that makes a total of `num_chunks` roughly equal-size chunks. - Here only the first `nrows` lines are read. This function should be used when `nrows` is - relatively small compared to the size of `filename`. - To find offsets of the entire file, please use `stopes.utils.file_chunker_utils.find_offsets()` - """ - offsets = [] - r = nrows % num_chunks - chunk_size = nrows // num_chunks - with open(filename, "r", encoding="utf-8") as f: - - # Each of the r first chunks has one more line than the rest num_chunks - r - size = chunk_size + 1 - for _ in range(r): - offsets.append(f.tell()) - [f.readline() for _ in range(size)] - - for _ in range(0, num_chunks - r): - offsets.append(f.tell()) - [f.readline() for _ in range(chunk_size)] - - offsets.append(f.tell()) - - return offsets - - -def make_one_shard( - input_file: Path, - header: bool = False, - sep: tp.Optional[str] = None, - nrows: tp.Optional[int] = None, -) -> tp.Sequence[Shard]: - if header: - assert ( - sep is not None - ), "Please provide separator input. Sharder will not guess the file format cannot be guessed at this point" - cols = parse_header(input_file, sep=sep) - else: - cols = None - if nrows: - return [TopNShard(input_file, cols, sep, 0, nrows)] - else: - return [Shard(input_file, cols, sep, 0)] - - -def make_chunk_shards( - input_file: Path, - nshards: int = 1, - header: bool = False, - sep: tp.Optional[str] = None, - cache_dir: tp.Optional[Path] = None, - nrows: tp.Optional[int] = None, -) -> tp.Sequence[ChunkShard]: - """ - Create multiple shards from a single file, where each share corresponds to a continuous - chunk of lines in the file. If the file is compressed, it will be decompressed into - the `cache_dir` under the new name: `input_file_expanded.txt` - """ - - from stopes.utils.file_chunker_utils import find_offsets - - if input_file.suffix in {".gz", ".xz"}: - warn_once( - "Multiple shards for compressed file is asked. Chunking the compressed file results " - "in a slow scheduler warm-up. Please give the decompressed file if possible." - ) - _cache = cache_dir if cache_dir else input_file.parent - assert Path( - _cache - ).exists(), ( - f"cache directory {_cache} not found, cannot write intermediate files" - ) - input_file = expand_if_compressed(input_file, _cache) # type: ignore - - if nrows: - offsets = find_offsets_of_lines(str(input_file), nshards, nrows) - else: - offsets = find_offsets(str(input_file), nshards) - - # Convert [pos1, pos2, pos3,...] 
to [(pos1, pos2), (pos2, pos3),..] - file_chunks = zip(offsets, offsets[1:]) - if header: - assert ( - sep is not None - ), "Please provide separator input. Sharder will not guess the file format cannot be guessed at this point" - cols = parse_header(input_file, sep=sep) - else: - cols = None - return [ - ChunkShard(input_file, cols, sep, i, start, end) - for i, (start, end) in enumerate(file_chunks) - ] - - -def make_roundrobin_shards( - input_file: Path, - nshards: int = 1, - header: bool = False, - sep: tp.Optional[str] = None, -) -> tp.Sequence[RoundRobinShard]: - """ - Make multiple shards from a single file where each shard correspond to all lines in the file - with certain index. For example, if there are 8 shards, shard_0 corresponds to lines 0, 8, 16,.. - """ - if header: - assert ( - sep is not None - ), "Please provide separator input. Sharder will not guess the file format cannot be guessed at this point" - cols = parse_header(input_file, sep=sep) - else: - cols = None - return [RoundRobinShard(input_file, cols, sep, i, nshards) for i in range(nshards)] - - -def make_sorted_shards( - input_file: Path, - nshards: int = 1, - header: bool = False, - sep: tp.Optional[str] = None, - cache_dir: tp.Optional[Path] = None, - col: tp.Optional[tp.Union[str, int]] = None, - no_duplicate: bool = False, -) -> tp.Sequence[Shard]: - """ - Create shards from one file by sorting the lines after values in column `col` and divide the - sorted lines into different chunks. This algorithm requires input file to be uncompressed - before. If `no_duplicate` is True, the shard will not have duplicates - """ - - from stopes.utils.file_chunker_utils import find_offsets - - if input_file.suffix in {".gz", ".xz"}: - warn_once( - "Multiple shards for compressed file is asked. Chunking the compressed file results " - "in a slow scheduler warm-up. Please give the decompressed file if possible." - ) - _cache = cache_dir if cache_dir else input_file.parent - assert Path( - _cache - ).exists(), ( - f"cache directory {_cache} not found, cannot write intermediate files" - ) - input_file = expand_if_compressed(input_file, _cache) # type: ignore - - sorted_file = str(input_file) + ".merge_sort" - sort_file(input_file, sorted_file, col=col, sep=sep, no_duplicate=no_duplicate) - offsets = find_offsets(str(sorted_file), nshards) - - # Convert [pos1, pos2, pos3,...] to [(pos1, pos2), (pos2, pos3),..] - file_chunks = zip(offsets, offsets[1:]) - if header: - assert ( - sep is not None - ), "Please provide separator input. Sharder will not guess the file format cannot be guessed at this point" - cols = parse_header(input_file, sep=sep) - else: - cols = None - return [ - ChunkShard(sorted_file, cols, sep, i, start, end) - for i, (start, end) in enumerate(file_chunks) - ] - - -def make_shards( - input: tp.Union[str, tp.List, Path], - nshards: int = 1, - algo: str = "chunk", - header: bool = False, - sep: tp.Optional[str] = None, - cache_dir: tp.Optional[Path] = None, - **kwargs, -) -> tp.Sequence[Shard]: - """ - Make shards from an `input`. - - Args: - input (str, list or Path): Input to gnerate the shards. It could be: - nshards: The number of shards to generate from the input. Only applicable if `input` - is a single file. - header (bool): Whether or not the input files have headers. 
Default False - sep: separator for columns in the line (all files in `input` must have the same format and separator) - cache_dir (str, Path, optional): directory to cache the intermedia files (such as uncompressed input file) - - - """ - assert nshards > 0, f"invalid number of shards ({nshards})" - if isinstance(input, tp.List): - return [ - s - for f in input - for s in make_shards(f, 1, algo, header, sep, cache_dir, **kwargs) - ] - elif (p := Path(input)).is_dir(): - return [ - s - for f in p.iterdir() - for s in make_shards(f, 1, algo, header, sep, cache_dir, **kwargs) - ] - elif not p.is_file(): - return [ - s - for f in glob.glob(str(input)) - for s in make_shards(f, 1, algo, header, sep, cache_dir, **kwargs) - ] - elif nshards == 1: - return make_one_shard(p, header, sep) - elif algo == "chunk": - return make_chunk_shards( - p, nshards, header, sep, cache_dir, kwargs.get("nrows") - ) - elif algo == "robin": - return make_roundrobin_shards(p, nshards, header, sep) - elif algo == "sort": - return make_sorted_shards( - p, - nshards=nshards, - header=header, - sep=sep, - cache_dir=cache_dir, - col=kwargs.get("col"), - no_duplicate=bool(kwargs.get("no_duplicate")), - ) - - raise ValueError( - f"invalid input: input={str(input)}, nshards={nshards}, algo={algo}" - ) diff --git a/stopes/utils/test_hf_shards.py b/stopes/utils/test_hf_shards.py new file mode 100644 index 0000000..aaa080d --- /dev/null +++ b/stopes/utils/test_hf_shards.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +from stopes.utils.sharding.hf_shards import HFInputConfig, HFShard + +# TODO: Hard code this to test if there are changes in HF datasets API +first_item_id = 7 + + +def test_shard_iteration(): + shard = HFShard( + filter=None, + path_or_name="Fraser/mnist-text-small", + split="test", + index=0, + num_shards=50, + trust_remote_code=True, + ) + with shard: + item = next(iter(shard)) + assert isinstance(item, dict) + assert "label" in item + assert item["label"] == first_item_id + + with shard as progress: + batch_iter = progress.to_batches(batch_size=4) + item = next(batch_iter) + assert item["label"][0].as_py() == first_item_id # type: ignore + + +def test_input_config(): + input_config = HFInputConfig( + input_file="Fraser/mnist-text-small", + split="test", + num_shards=50, + trust_remote_code=True, + ) + shards = input_config.make_shards() + first_shard = shards[0] + with first_shard: + item = next(iter(first_shard)) + assert item["label"] == first_item_id diff --git a/stopes/utils/test_json_shards.py b/stopes/utils/test_json_shards.py new file mode 100644 index 0000000..8db1db4 --- /dev/null +++ b/stopes/utils/test_json_shards.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
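The HF-backed shards exercised in `test_hf_shards.py` above follow the same enter-then-iterate protocol as the text shards. Below is a minimal sketch of scanning every shard of the same toy dataset and collecting one column; the batch size and the `to_pylist()` call are illustrative assumptions (the test itself only shows indexing a batch column with `.as_py()`), and the dataset is fetched from the Hugging Face Hub.

from stopes.utils.sharding.hf_shards import HFInputConfig

config = HFInputConfig(
    input_file="Fraser/mnist-text-small",  # same toy dataset as the test above
    split="test",
    num_shards=50,
    trust_remote_code=True,
)

shards = config.make_shards()
labels = []
for shard in shards:
    with shard as progress:
        for batch in progress.to_batches(batch_size=64):
            # Batches are Arrow-backed; "label" is a column of this particular dataset.
            labels.extend(batch["label"].to_pylist())

print(f"read {len(labels)} rows across {len(shards)} shards")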
+ + +import json +from dataclasses import asdict, is_dataclass +from enum import Enum +from typing import Any, List + +from stopes.utils.sharding.abstract_shards import BatchFormat +from stopes.utils.sharding.json_shards import JSONShard, JSONShardConfig + + +def _default_json_encoder(o: Any) -> Any: + if is_dataclass(o): + return asdict(o) + if isinstance(o, Enum): + return o.value + if callable(o): + return repr(o) + return json.JSONEncoder().default(o) + + +def test_json_shard_config(tmp_path): + ids = range(10) + data = list("helloworld") + + json_data = [{"id": i, "char": d} for i, d in zip(ids, data)] + test_file = tmp_path.joinpath("test.json") + with open(test_file, encoding="utf-8", mode="w") as o: + for item in json_data: + o.write(json.dumps(item, default=_default_json_encoder) + "\n") + + input_config = JSONShardConfig(input_file=test_file, num_shards=3) + + shards: List[JSONShard] = input_config.make_shards() # type: ignore + + # Test batch API + text = "" + for shard in shards: + for batch in shard.to_batches(batch_size=3): + text += "".join(batch["char"]) + assert text == "helloworld" + + # test __iter__ API . This API mains the context explicitly, just + # like text_shards + text = "" + for shard in shards: + with shard as shard_context: + text += "".join([item["char"] for item in iter(shard_context)]) + assert text == "helloworld" + + # Test json sharding with nrows option + input_config_nrows = JSONShardConfig(input_file=test_file, num_shards=2, nrows=5) + small_shards = input_config_nrows.make_shards() + text = "" + for shard in small_shards: # type: ignore + for batch in shard.to_batches(batch_size=3): + text += "".join(batch["char"]) + assert text == "hello" diff --git a/stopes/utils/tests/conftest.py b/stopes/utils/tests/conftest.py new file mode 100644 index 0000000..62b9286 --- /dev/null +++ b/stopes/utils/tests/conftest.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
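`_default_json_encoder` in the JSON shard test above exists so that rows holding dataclasses or enums can be dumped straight to JSONL. A tiny, self-contained illustration of that encoder pattern follows; the `Row` and `Color` types are invented for the example.

import json
from dataclasses import asdict, dataclass, is_dataclass
from enum import Enum


class Color(Enum):
    RED = "red"


@dataclass
class Row:
    id: int
    color: Color


def default(o):
    # Same fallbacks as _default_json_encoder above: dataclass -> dict, Enum -> value.
    if is_dataclass(o):
        return asdict(o)
    if isinstance(o, Enum):
        return o.value
    return repr(o)


print(json.dumps(Row(id=1, color=Color.RED), default=default))
# -> {"id": 1, "color": "red"}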
+ + +import random +import shutil +import string +import tempfile +import typing as tp +from pathlib import Path + +import numpy as np +import pandas as pd +import pyarrow as pa +import pyarrow.parquet as pq +import pytest + + +def sanitize_categorical(df: pd.DataFrame) -> pd.DataFrame: + for col, type_ in df.dtypes.items(): + if type_ == "category": + df[col] = df[col].astype("object") + return df + + +def permutationally_equal_dataframes(one: pd.DataFrame, other: pd.DataFrame) -> bool: + assert len(one) == len(other) + assert sorted(one.columns) == sorted(other.columns) + one = sanitize_categorical(one)[sorted(one.columns)] + other = sanitize_categorical(other)[sorted(other.columns)] + + one = one.sort_values(list(one.columns)).set_index(np.arange(len(one))) + other = other.sort_values(list(other.columns)).set_index(np.arange(len(other))) + + assert (one == other).all(axis=None) + return True + + +def get_random_table(size: int, seed: int = 123) -> pa.Table: + rds = np.random.RandomState(seed) + data = { + "cat": rds.randint(0, 10, size), + "name": ["name_" + str(i) for i in range(size)], + "score": np.round(rds.randn(size), 7), + } + return pa.Table.from_pydict(data) + + +@pytest.fixture() +def single_file_dataset() -> tp.Generator[Path, None, None]: + tmpdir = tempfile.mkdtemp() + tmp_parquet_ds_path = Path(tmpdir) / "test1" + table = get_random_table(10**3) + pq.write_table(table, tmp_parquet_ds_path) + yield tmp_parquet_ds_path + shutil.rmtree(tmpdir) + + +@pytest.fixture() +def multi_partition_file_dataset() -> tp.Generator[Path, None, None]: + tmpdir = tempfile.mkdtemp() + tmp_parquet_ds_path = Path(tmpdir) / "test2" + + table = get_random_table(10**3) + pq.write_to_dataset(table, tmp_parquet_ds_path, partition_cols=["cat"]) + + yield tmp_parquet_ds_path + shutil.rmtree(tmpdir) + + +def gen_random_string(length: int) -> str: + return "".join( + random.choice(string.ascii_letters + string.digits) for n in range(length) + ) + + +def generate_random_pandas_df(size: int, seed: int = 123) -> pd.DataFrame: + np_rs = np.random.RandomState(seed) + df: tp.Dict[str, tp.Union[np.ndarray, list]] = {} + df["int_col"] = np_rs.randint(0, 200, size) + df["float_col"] = np_rs.randn(size) + + df["string_col1"] = [gen_random_string(10) for _ in range(size)] + df["string_col2"] = [gen_random_string(2) for _ in range(size)] + + df["list_int_col"] = [ + np_rs.randint(-10, 10, np_rs.randint(0, 100)) for _ in range(size) + ] + df["list_float_col"] = [ + np_rs.rand(np_rs.randint(0, 10)).astype(np.float32) for _ in range(size) + ] + df["list_float_fixed_size_col"] = [ + np_rs.rand(7).astype(np.float32) for _ in range(size) + ] + return pd.DataFrame(df) + + +def generated_partitioned_parquet_file( + path: str, size: int, n_partitions: int = 20, seed: int = 123 +) -> None: + df = generate_random_pandas_df(size, seed) + + if n_partitions > 0: + df["part_key"] = np.arange(size) % n_partitions + + table = pa.Table.from_pandas(df) + + pq.write_to_dataset( + table, + path, + partition_cols=["part_key"] if n_partitions > 0 else None, + existing_data_behavior="delete_matching", + ) diff --git a/stopes/utils/tests/test_parquet_dataloader.py b/stopes/utils/tests/test_parquet_dataloader.py index 75d960b..ad70142 100644 --- a/stopes/utils/tests/test_parquet_dataloader.py +++ b/stopes/utils/tests/test_parquet_dataloader.py @@ -5,61 +5,14 @@ # LICENSE file in the root directory of this source tree. 
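For orientation, the `multi_partition_file_dataset` and `generated_partitioned_parquet_file` helpers above rely on pyarrow's hive-style partitioning: `partition_cols` turns each distinct key into a `cat=<value>/` sub-directory. A small, fixture-independent sketch of that layout and the read-back behaviour (values are deterministic toy data):

import tempfile
from pathlib import Path

import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq

tmp = Path(tempfile.mkdtemp())
table = pa.table({"cat": np.arange(100) % 3, "score": np.arange(100, dtype="float64")})

# Same call as the fixture: one sub-directory per distinct "cat" value.
pq.write_to_dataset(table, tmp, partition_cols=["cat"])
print(sorted(p.name for p in tmp.iterdir()))  # ['cat=0', 'cat=1', 'cat=2']

# Reading the directory back reconstructs "cat" as a (dictionary-encoded) column.
roundtrip = pq.read_table(tmp)
print(roundtrip.schema.names)  # ['score', 'cat']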
import os -import random import shutil -import string import tempfile -import typing as tp import unittest from collections import Counter -from typing import Optional -from stopes.utils.parquet_dataloader import ParquetBasicDataLoader, np, pa, pd, pq +from conftest import generated_partitioned_parquet_file - -def gen_random_string(length: int) -> str: - return "".join( - random.choice(string.ascii_letters + string.digits) for n in range(length) - ) - - -def generate_random_pandas_df(size: int, seed: int = 123) -> pd.DataFrame: - np_rs = np.random.RandomState(seed) - df: tp.Dict[str, tp.Union[np.ndarray, list]] = {} - df["int_col"] = np_rs.randint(0, 200, size) - df["float_col"] = np_rs.randn(size) - - df["string_col1"] = [gen_random_string(10) for _ in range(size)] - df["string_col2"] = [gen_random_string(2) for _ in range(size)] - - df["list_int_col"] = [ - np_rs.randint(-10, 10, np_rs.randint(0, 100)) for _ in range(size) - ] - df["list_float_col"] = [ - np_rs.rand(np_rs.randint(0, 10)).astype(np.float32) for _ in range(size) - ] - df["list_float_fixed_size_col"] = [ - np_rs.rand(7).astype(np.float32) for _ in range(size) - ] - return pd.DataFrame(df) - - -def generated_partitioned_parquet_file( - path: str, size: int, n_partitions: int = 20, seed: int = 123 -) -> None: - df = generate_random_pandas_df(size, seed) - - if n_partitions > 0: - df["part_key"] = np.arange(size) % n_partitions - - table = pa.Table.from_pandas(df) - - pq.write_to_dataset( - table, - path, - partition_cols=["part_key"] if n_partitions > 0 else None, - existing_data_behavior="delete_matching", - ) +from stopes.utils.parquet_dataloader import ParquetBasicDataLoader, np, pa, pd class TestParquetDataloader(unittest.TestCase): diff --git a/stopes/utils/tests/test_parquet_shards.py b/stopes/utils/tests/test_parquet_shards.py new file mode 100644 index 0000000..0d55809 --- /dev/null +++ b/stopes/utils/tests/test_parquet_shards.py @@ -0,0 +1,290 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
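The parquet sharding tests below drive filtering through `filters_expr` strings that name `ds.field(...)` expressions. For reference, the same predicate built directly with pyarrow, independent of stopes (the data values are made up):

import pyarrow as pa
import pyarrow.dataset as ds

table = pa.table({"cat": [1, 5, 7, 9], "score": [-0.5, -0.1, 0.3, -2.0]})

# Same predicate as the filters_expr strings used in the tests below.
expr = (ds.field("score") < 0) & (ds.field("cat") >= 5)

print(table.filter(expr).to_pydict())  # {'cat': [5, 9], 'score': [-0.1, -2.0]}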
+ + +import shutil +import tempfile + +import pyarrow as pa +import pyarrow.dataset as ds +import pyarrow.parquet as pq +from conftest import get_random_table, permutationally_equal_dataframes + +from stopes.utils.sharding.abstract_shards import BatchFormat +from stopes.utils.sharding.parquet_shards import ( + ParquetOutputConfig, + ParquetShard, + ParquetShardingConfig, +) + + +def test_parquet_to_batches(single_file_dataset): + input_config = ParquetShardingConfig( + input_file=single_file_dataset, + fragment_group_size=1, + batch_size=11, + batch_format=BatchFormat.ARROW, + ) + shard = input_config.make_shards()[0] + with shard: + bb = next(shard.to_batches(4, batch_format=BatchFormat.ARROW)) + + bb_head = input_config.head(4) + assert isinstance(bb, pa.Table) + assert bb.equals(bb_head) + + assert len(bb.schema.names) == 3 + assert len(bb) == 4 + + +def test_parquet_sharding_config_single_file(single_file_dataset): + psc = ParquetShardingConfig( + input_file=str(single_file_dataset), + batch_size=11, + columns=["score", "cat", "name"], + batch_format=BatchFormat.ARROW, + filters_expr="(ds.field('score') < 0) & (ds.field('cat') >= 5)", + ) + + assert psc.partition_columns == [] + shards = psc.make_shards() + assert len(shards) == 1 + # all columns are present + assert set(tuple(shard.columns) for shard in shards) == set( # type: ignore + [("cat", "name", "score")] + ) + assert set(tuple(shard.partition_columns) for shard in shards) == set([()]) # type: ignore + + first_shard = shards[0] + + # working with full batch + batches = list( + first_shard.to_batches( + batch_size=None, columns=psc.columns, batch_format=psc.batch_format + ) + ) + assert len(batches) == 1 + assert [len(bb) for bb in batches] == [245] + assert [list(bb.column_names) for bb in batches] == [["cat", "name", "score"]] # type: ignore + + one_batch = batches[0] + assert one_batch["score"].to_numpy().max() < 0 # type: ignore + assert one_batch["cat"].to_numpy().min() >= 5 # type: ignore + + batches = list( + first_shard.to_batches( + batch_size=psc.batch_size, + columns=psc.columns, + batch_format=psc.batch_format, + ) + ) + assert len(batches) == 23 + assert [len(bb) for bb in batches] == [11] * 22 + [3] + assert [list(bb.column_names) for bb in batches] == 23 * [["cat", "name", "score"]] # type: ignore + + one_batch = batches[3] + assert one_batch["score"].to_numpy().max() < 0 # type: ignore + assert one_batch["cat"].to_numpy().min() >= 5 # type: ignore + + +def test_parquet_sharding_config(multi_partition_file_dataset): + psc = ParquetShardingConfig( + input_file=str(multi_partition_file_dataset), + batch_size=10, + columns=["score", "cat"], + batch_format=BatchFormat.PANDAS, + filters_expr="(ds.field('score') < 0) & (ds.field('cat') >= 5)", + ) + + assert psc.partition_columns == ["cat"] + shards = psc.make_shards() + assert len(shards) == 5 + # all columns are present + assert set(tuple(shard.columns) for shard in shards) == set( # type: ignore + [("name", "score", "cat")] + ) + assert set(tuple(shard.partition_columns) for shard in shards) == set([("cat",)]) # type: ignore + first_shard = shards[0] + batches = list( + first_shard.to_batches( + batch_size=psc.batch_size, + columns=psc.columns, + batch_format=psc.batch_format, + ) + ) + assert len(batches) == 5 + assert [len(bb) for bb in batches] == [10, 10, 10, 10, 8] + assert [list(bb.columns) for bb in batches] == 5 * [psc.columns] # type: ignore + + one_batch = batches[3] + assert one_batch["score"].max() < 0 + assert one_batch["cat"].astype("int").min() >= 5 + 
+ # working with full batch + batches = list( + first_shard.to_batches( + batch_size=None, columns=psc.columns, batch_format=psc.batch_format + ) + ) + assert len(batches) == 1 + assert [len(bb) for bb in batches] == [48] + assert [list(bb.columns) for bb in batches] == [psc.columns] # type: ignore + + one_batch = batches[0] + assert one_batch["score"].max() < 0 + assert one_batch["cat"].astype("int").min() >= 5 + + +def test_parquet_output_config(): + tmpdir = tempfile.mkdtemp() + + poc = ParquetOutputConfig( + dataset_path=tmpdir, + validate_schema=True, + compression="gzip", + row_group_size=300, + keep_same_partitioning=False, + partition_columns=["cat"], + ) + + table = get_random_table(100) + poc.expected_schema = table.schema + + single_path_ = poc.write_batch( + table.filter(ds.field("cat") == 5), iteration_index=(111,) + ) + assert len(single_path_) == 1 + single_path = single_path_[0] + assert str(single_path).endswith("_111.parquet") + assert str(single_path.parent.name) == "cat=5" + assert permutationally_equal_dataframes( + pq.read_table(poc.dataset_path).to_pandas(), + table.filter(ds.field("cat") == 5).to_pandas(), + ) + + multi_path_ = poc.write_batch( + table.filter(ds.field("cat").isin([3, 2])), iteration_index=(222,) + ) + assert len(multi_path_) == 2 + assert all(str(single_path).endswith("_222.parquet") for single_path in multi_path_) + assert str(multi_path_[0].parent.name) == "cat=2" + assert str(multi_path_[1].parent.name) == "cat=3" + expected_table = pa.concat_tables( + [table.filter(ds.field("cat") == 5), table.filter(ds.field("cat").isin([3, 2]))] + ) + assert permutationally_equal_dataframes( + pq.read_table(poc.dataset_path).to_pandas(), expected_table.to_pandas() + ) + + try: # changing schema + poc.write_batch(table.select(["cat"]), iteration_index=(2,)) + raise Exception("should not get there") + except Exception as ex: + assert "Item has schema" in str(ex) + + shutil.rmtree(tmpdir) + + +def test_parquet_sharding_config_single_file_many_fragements( + multi_partition_file_dataset, +): + psc = ParquetShardingConfig( + input_file=str(multi_partition_file_dataset), + batch_size=11, + fragment_group_size=2, + batch_format=BatchFormat.ARROW, + filters_expr="(ds.field('score') < 0) & (ds.field('cat') >= 2)", + ) + + assert psc.partition_columns == ["cat"] + shards = psc.make_shards() + assert len(shards) == 4 + + # all columns are present + assert set(tuple(shard.columns) for shard in shards) == set( # type: ignore + [("name", "score", "cat")] + ) + assert set(tuple(shard.partition_columns) for shard in shards) == set([("cat",)]) # type: ignore + + ## full dataset + full_ds = pa.concat_tables( + [ + bb + for shard in shards + for bb in shard.to_batches( + None, columns=psc.columns, batch_format=psc.batch_format + ) + ] + ) + reload_full_ds = pq.read_table(str(multi_partition_file_dataset)).filter(psc.filter) + assert permutationally_equal_dataframes( + full_ds.to_pandas(), reload_full_ds.to_pandas() + ) + + first_shard = shards[0] + assert isinstance(first_shard, ParquetShard) + assert first_shard.nb_rows == 195 # unfiltered numbers + + # working with full batch + batches = list( + first_shard.to_batches( + batch_size=None, columns=psc.columns, batch_format=psc.batch_format + ) + ) + assert len(batches) == 1 + assert [len(bb) for bb in batches] == [94] + assert [list(bb.column_names) for bb in batches] == [["name", "score", "cat"]] # type: ignore + + one_batch = batches[0] + assert one_batch["score"].to_numpy().max() < 0 # type: ignore + assert 
one_batch["cat"].to_numpy().min() >= 2 # type: ignore + + batches = list( + first_shard.to_batches( + batch_size=psc.batch_size, + columns=psc.columns, + batch_format=psc.batch_format, + ) + ) + assert len(batches) == 9 + assert [len(bb) for bb in batches] == [11] * 8 + [6] + + +def test_parquet_sharding_config_single_file_many_fragements_and_nb_samples_per_group( + multi_partition_file_dataset, +): + psc = ParquetShardingConfig( + input_file=str(multi_partition_file_dataset), + batch_size=11, + nb_samples_per_group=330, + batch_format=BatchFormat.ARROW, + filters_expr="ds.field('cat') >= 1", + ) + + assert psc.partition_columns == ["cat"] + shards = psc.make_shards() + assert len(shards) == 3 + + # all columns are present + assert set(tuple(shard.columns) for shard in shards) == set( # type: ignore + [("name", "score", "cat")] + ) + assert set(tuple(shard.partition_columns) for shard in shards) == set([("cat",)]) # type: ignore + + ## full dataset + tables = [ + bb + for shard in shards + for bb in shard.to_batches( + None, columns=psc.columns, batch_format=psc.batch_format + ) + ] + assert [len(bb) for bb in tables] == [401, 400, 96] + full_ds = pa.concat_tables(tables) + reload_full_ds = pq.read_table(str(multi_partition_file_dataset)).filter(psc.filter) + assert permutationally_equal_dataframes( + full_ds.to_pandas(), reload_full_ds.to_pandas() + ) diff --git a/stopes/utils/tests/test_shards.py b/stopes/utils/tests/test_shards.py index 8e2b3eb..09bba30 100644 --- a/stopes/utils/tests/test_shards.py +++ b/stopes/utils/tests/test_shards.py @@ -3,394 +3,833 @@ # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +# +# +# Different methods to support sharding for audio files -import contextlib +import functools +import io import itertools import logging -import os -import random -import re +import shutil +import typing as tp +import uuid +from dataclasses import dataclass +from glob import glob from pathlib import Path -from typing import Iterator - -import pytest - -from stopes.core.utils import open_write -from stopes.utils import shards as sh -from stopes.utils.shards import ChunkShard, RoundRobinShard, Shard - -def test_shard_context(tmp_path: Path): - tmp_file = tmp_path / "tmp.tsv.gz" - - # Test that when a shard enters, the underlying input file must exist - with pytest.raises(FileNotFoundError): - with Shard(tmp_file) as _: - pass +import cloudpickle +import numpy as np +import pandas as pd +import pyarrow as pa +import xxhash +from omegaconf.listconfig import ListConfig +from pyarrow import csv as csv_pa + +from stopes.core.utils import expand_if_compressed +from stopes.core.utils import open as stopes_open +from stopes.core.utils import sort_file +from stopes.utils.arrow_utils import hash_table_with_schema +from stopes.utils.file_chunker_utils import find_offsets_of_lines +from stopes.utils.sharding.abstract_shards import ( + BatchFormat, + BatchType, + InputShardingConfig, + OutputDatasetConfig, + PartitionedDataMapperState, + Shard, + arrow_table_to_batch, + batch_length, + batch_to_pandas, + batch_to_table, +) -def test_shard_internal_update(tmp_path: Path): - tmp_file = tmp_path / "tmp.tsv.gz" - # Test that the shard.lines is properly updated during shard iteration - line_cnt = 10 - with open_write(tmp_file, "w") as o: - [o.write(f"line{i}\n") for i in range(line_cnt)] +@functools.lru_cache(10) +def warn_once(msg: str) -> None: + """Prevents flooding stderr with the same repeated error message.""" + 
log.warning(msg) + + +log = logging.getLogger("stopes.speech.shards") + + +@dataclass +class TextShard(Shard): + """ + input to one worker processing a file shard for an array module. Default behaviour + is that the worker will process an entire file. + + A shard is a contextmanager object: When you enter a shard in a local job, it gives + access from the input file resource (by default via `stopes.core.utils.open()`). + + A shard is also an iterator: You lazily reads each line after entering the shard. It + will update the internal states silently, to ensure the reading can be recovered if + the job needs to be re-run. + Note that this recovery is only guaranteed within one (slurm) job or machine, and not + if the whole pipeline is re-run, because a Shard object - once created - will be sent + and kept locally to each job only. + + Args: + input_file (Path): The input file. + columns (list or bool, optional): a list of header columns. None if there is no header + sep (optional): the separator of lines. Only applicable when `cols` is not None + index: index of the shard. None if there is only one shard for the file + path_column : when not None, means the column's name (returned with to_batches()) + containing the file path from which the corresponding data is read. + If None, no extra column is added. + + """ + + input_file: tp.Union[str, Path] + columns: tp.Optional[tp.List[str]] = None + sep: tp.Optional[str] = None + index: tp.Optional[int] = None + path_column: tp.Optional[str] = None + + def __post_init__(self): + """Prepare internal properties""" + super().__post_init__() + # Keep how many lines already processed. Use to re-run the job + self._lines_cnt: int = 0 + + # handle the input resource + self._input_handler: tp.Optional[tp.ContextManager] = None + self._reader: tp.Optional[tp.Iterator[str]] = None + + def __enter__(self) -> "TextShard": + if not Path(self.input_file).exists(): + raise FileNotFoundError(self.input_file) + self._reader = self.input_handler.__enter__() + return self + + @property + def input_handler(self) -> tp.ContextManager: + if self._input_handler is None: + self._input_handler = stopes_open(self.input_file) + return self._input_handler + + def resolve_column_index(self, column_name: tp.Union[int, str]) -> int: + if isinstance(column_name, int) or column_name.isdecimal(): + return int(column_name) + assert ( + isinstance(self.columns, tp.List) and len(self.columns) > 0 + ), f"{self.input_file} has no header" + try: + return self.columns.index(column_name) + except ValueError: + raise ValueError( + f"Column {column_name} not found in header of {self.input_file}: {self.columns}" + ) + + def value(self, column_name: tp.Union[int, str]) -> str: + """Get value from a given column in the current line""" + + column_offset = self.resolve_column_index(column_name) + lines = self.line.rstrip().split(self.sep) + return lines[column_offset] + + def __iter__(self) -> tp.Iterator[str]: + """start or resume the input file consumption from the last attempt.""" + + lines = iter(self._reader) # type: ignore + if self.has_started(): + log.info( + f"Resuming from previous attempt, already processed {self._lines_cnt} lines" + ) + # Skip processed lines + skipped_lines = int(self.contains_header()) + self._lines_cnt + for line in itertools.islice(lines, skipped_lines, None): + # Keep track of current line and processed lines so far + self.line = line + self._lines_cnt += 1 + yield line + + def has_started(self): + """whether the shard is already (partially) processed""" + return 
self._lines_cnt > 0 + + def contains_header(self) -> bool: + """whether the corresponding shard contains header""" + return bool(self.columns) and not bool(self.index) + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.input_handler.__exit__(exc_type, exc_val, exc_tb) + self._input_handler = None + self._reader = None + + def to_batches( + self, + batch_size: tp.Optional[int], + columns: tp.Optional[tp.List[str]] = None, + batch_format: BatchFormat = BatchFormat.PANDAS, + ) -> tp.Iterator[BatchType]: + assert batch_size is None or batch_size > 0 + + if columns is not None and self.columns is not None: + assert set(columns).issubset(set(self.columns)) + + if columns is None: + columns = self.columns + + read_options = csv_pa.ReadOptions( + use_threads=True, column_names=self.columns, encoding="utf8" + ) + parse_options = csv_pa.ParseOptions(delimiter=self.sep, ignore_empty_lines=True) + convert_options = csv_pa.ConvertOptions(include_columns=columns) + + old_lines_cnt, self._lines_cnt = self._lines_cnt, 0 + with self as reading_context: + stream = io.BytesIO("".join(reading_context).encode()) + table = csv_pa.read_csv( + stream, + read_options=read_options, + parse_options=parse_options, + convert_options=convert_options, + ) + stream.close() + self._lines_cnt = old_lines_cnt + + if self.path_column: + table = table.append_column( + self.path_column, + pa.DictionaryArray.from_arrays( + np.zeros(len(table), dtype=np.int32), [str(self.input_file)] + ), + ) + + if self.filter is not None: + table = table.filter(self.filter) + if len(table) > 0: + if batch_size is None: + yield arrow_table_to_batch(table, batch_format) + else: + for tt in table.to_batches(max_chunksize=batch_size): + min_table = pa.Table.from_batches([tt]) + yield arrow_table_to_batch(min_table, batch_format) + + +@dataclass +class TextShardingConfig(InputShardingConfig): + nb_shards: int = 1 + sharding_strategy: str = "chunk" + header: tp.Any = True # should be tp.Optional[tp.Union[bool, tp.List[str]]] + sep: tp.Optional[str] = None + cache_dir: tp.Optional[Path] = None + path_column: tp.Optional[str] = None + # TODO: restrict only on supported sharding strategies + # TODO: split by given number of rows + """ + - header: either bool (True meaning the presence of header in a file) or explicit list of resulting column names + - path_column : when not None, means the column's name (returned with shard.to_batches()) + containing the file path from which the corresponding data is read. If None, no extra column is added. 
+ """ + + def __post_init__(self): + super().__post_init__() + assert self.nb_shards > 0, f"invalid number of shards ({self.nb_shards})" + assert ( + len(self.skip_n_rows_per_shard) == 0 + ), f"skipping not supported for this shard type" + + def validate(self) -> None: + # TODO: verify that files exists and are readable with provided parameters + pass + + def make_shards(self, **kwargs) -> tp.List[Shard]: + shards: tp.List[Shard] = list( + make_text_file_shards( + input=( + self.input_dataset[0] + if len(self.input_dataset) == 1 + else self.input_dataset + ), + nshards=self.nb_shards, + algo=self.sharding_strategy, + header=self.header, # type: ignore + sep=self.sep, + cache_dir=self.cache_dir, + filter=self.filter, + **kwargs, + ) + ) + shards = shards[: self.take] + if self.path_column: + for shard in shards: + shard.path_column = self.path_column # type: ignore + return shards + + +@dataclass +class TextOutputConfig(OutputDatasetConfig): + header: bool = True + sep: str = "\t" + storage_options: tp.Optional[tp.Dict[str, str]] = None + quoting: tp.Optional[int] = None + + """ + """ + + def __post_init__(self) -> None: + super().__post_init__() + + assert self.sep in [ + ",", + "\t", + ], f"only comma and tab are supported as separators, got {self.sep}" + if self.compression == "default": + self.compression = None + if self.validate_schema: + raise NotImplementedError("not supported yet for text files") + + assert self.compression in [ + None, + "zip", + "gzip", + "bz2", + "zstd", + "xz", + "tar", + ], f"unsupported compression {self.compression}" + Path(self.dataset_path).mkdir(parents=True, exist_ok=True) + + @staticmethod + def compression_to_extension(compression: tp.Optional[str]) -> str: + if compression is None: + return "" + return f".{compression}" + + @staticmethod + def separator_to_extension(sep: str) -> str: + mapping = {",": ".csv", "\t": ".tsv"} + return mapping.get(sep, ".txt") + + def write_batch( + self, + batch: BatchType, + iteration_index: tp.Sequence[int], + metadata: tp.Optional[tp.Dict[str, tp.Any]] = None, + state_checkpoint: tp.Optional[PartitionedDataMapperState] = None, + ) -> tp.List[Path]: + if batch is None or batch_length(batch) == 0: + # TODO: logger empty batch + return [] + + # TODO: reuse resolve_output logic here + try: + guid = hash_table_with_schema(batch_to_table(batch))[:20] + except Exception as e: + print(f"`hash_table_with_schema` failed : {e}") + guid = f"{uuid.uuid4()}"[:20] + + file_name = f"{guid}" + iteration_index = ( + (iteration_index,) if isinstance(iteration_index, int) else iteration_index + ) + for idx in iteration_index: + file_name += f"_{idx}" + file_name += f"{self.separator_to_extension(self.sep)}{self.compression_to_extension(self.compression)}" + + path = Path(self.dataset_path).joinpath(file_name) + + df_pd: pd.DataFrame = batch_to_pandas(batch) + + df_pd.to_csv( + path, + sep=self.sep, + header=self.header, + quoting=self.quoting, + compression=self.compression, + storage_options=self.storage_options, + index=False, + ) - shard = Shard(tmp_file) - break_point = 4 - with shard as first_pass: - for i, line in enumerate(iter(first_pass)): - assert line.rstrip() == f"line{i}" - if i == break_point: - break + if state_checkpoint: + shard_hash = xxhash.xxh3_64_intdigest( + cloudpickle.dumps(state_checkpoint.iteration_value) + ) + with (Path(self.dataset_path) / f".text_output.{shard_hash}.state").open( + "wb" + ) as f: # filename is wrong + cloudpickle.dump(state_checkpoint, f) + + # this could be interesing + # 
https://arrow.apache.org/docs/python/csv.html#incremental-writing + # it'll be about x3 - x4 faster for writing but we need to handle the compression and remote storage adhoc + # TODO : Write metadata + return [path] + + def reload_state( + self, + shard: Shard, + ) -> tp.Optional[PartitionedDataMapperState]: + try: + shard_hash = xxhash.xxh3_64_intdigest(cloudpickle.dumps(shard)) + with (Path(self.dataset_path) / f".text_output.{shard_hash}.state").open( + "rb" + ) as f: # filename is wrong + return cloudpickle.load(f) + except: + return None + + +@dataclass +class TopNShard(TextShard): + """ + progress of one worker processing a file up to top-N lines + """ + + nrows: tp.Optional[int] = None + + def __iter__(self) -> tp.Iterator[str]: + lines = super().__iter__() + lines = itertools.islice(lines, 0, self.nrows) + for line in lines: + yield line + + +@dataclass +class ChunkShard(TextShard): + """ + A shard that corresponds to a file contiguous chunk. + + Args: + start (int): start byte offset of the shard + end (int): end byte offset of the shard. None if the shard is to be processed till EOF + """ + + from stopes.utils.file_chunker_utils import Chunker + + start: int = 0 + end: tp.Optional[int] = None + + @property + def input_handler(self) -> tp.ContextManager: + if self._input_handler is None: + self._input_handler = self.Chunker( + str(self.input_file), self.start, self.end + ) + return self._input_handler + + +@dataclass +class RoundRobinShard(TextShard): + """ + A shard that corresponds to a subset of lines read from the file in the round robin fashion + + Args: + nshards: Number of the shards + """ + + nshards: int = 1 + + def __iter__(self) -> tp.Iterator[str]: + if self.has_started(): + log.info( + f"Resuming from previous attempt, already processed {self._lines_cnt} lines" + ) + skipped_lines = int(self.contains_header()) + self._lines_cnt + for i, line in enumerate(iter(self._reader)): # type: ignore + if i % self.nshards == self.index: + if skipped_lines == 0: + self.line = line + self._lines_cnt += 1 + yield line + else: + skipped_lines -= 1 + + +def resolve_output( + shard: TextShard, output_file: tp.Optional[Path] = None, suffix: str = "" +) -> tp.Optional[Path]: + """ + A convenience function to help users make a standard output filename for the shard. + Recommended if user wants a more consistent sharding output naming to be used in stopes pipeline + + The output suffix calibration logic: + + First, find the proper suffix (in order of priority): + + - If the input `suffix` is given, use it + - If the `output_file` if a file with suffixes, use it + - If the `output_file` is a directory, use input_file suffix. 
+ - If neither `output_file` nor `suffix` is given, use the input_file suffix + + In all cases, make sure the output is not compressed (even if input_file is compressed) + except the user explicitly wants it, either via `output_file` or `suffix` + + After that, prepend shard index to the output suffix + + Example: + + - ouput_file = out.txt , suffix = ".tsv.gz" , no shard --> output = out.tsv.gz + - ouput_file = out.txt , suffix = ".tsv.gz" , 2 shards --> outputs = out.0.tsv.gz , out.1.tsv.gz + - ouput_file = out.txt , suffix = "" , no shard --> output = out.txt + - ouput_file = out.tsv.gz , suffix = "" , 2 shards --> outputs = out.0.tsv.gz , out.1.tsv.gz + - ouput_file = out_dir , suffix = ".tsv.gz" , no shard --> output = out_dir / input.tsv.gz + - output_file = None, suffix = "", input = in.tsv.gz", no shard --> output = in.tsv + - output_file = None, suffix = "", input = in.tsv.gz", 2 shards --> output = in.0.tsv, in.1.tsv + - output_file = file_without_ext, suffix = "" , input = in.tsv.gz, 2 shards -> ouput = file_without_ext.0, file_without_ext.1 - assert shard.has_started() - with shard as second_pass: - for i, line in enumerate(iter(second_pass)): - assert line.rstrip() == f"line{i + break_point + 1}" - assert shard._lines_cnt == line_cnt + """ + # stoptes.utils.file_chunker_utils adds "_expanded.txt" to a compressed file + input_name = Path(shard.input_file).name.replace("_expanded.txt", "") + + # an intermediate file from stopes.utils.sort_files has a ".merge_sort" suffix + input_name = input_name.replace(".merge_sort", "") + # Unless user specifies suffix with .gz or .xz, we do not compress output + input_name = input_name.replace(".gz", "").replace(".xz", "") + + in_suffix = Path(input_name).suffix # .tsv or .txt + input_stem = Path(input_name).stem + + if suffix: + out_suffix = suffix + elif output_file is None or output_file.is_dir(): + out_suffix = in_suffix + else: + out_suffix = "".join(output_file.suffixes) -@pytest.mark.parametrize( - "input_file", ["tmp.tsv", "tmp.tsv.gz", "tmp.tsv_expanded.txt"] -) -@pytest.mark.parametrize("output_file", [None, "output.txt", "outdir", "outfile"]) -@pytest.mark.parametrize("suffix", ["", ".tsv.gz"]) -def test_resolve_output(tmp_path: Path, input_file, output_file, suffix): - tmp_file = tmp_path / input_file - if output_file: - (tmp_path / "outdir").mkdir() - if output_file == "outdir": - output_file = tmp_path / "outdir" - else: - output_file = tmp_path / "outdir" / output_file - output_file.touch() - - shard1 = Shard(tmp_file) - resolved_output_file = sh.resolve_output(shard1, output_file, suffix) + # If there are more than one shard for the file, add shard index to each output name + if shard.index is not None: + out_suffix = f".{shard.index}{out_suffix}" if output_file is None: - if suffix == "" and input_file == "tmp.tsv": - assert resolved_output_file is None - elif suffix == ".tsv.gz" and input_file == "tmp.tsv.gz": - assert resolved_output_file is None - else: - assert resolved_output_file is not None - assert resolved_output_file.parent == tmp_path - output_name = "tmp.tsv" if suffix == "" else "tmp.tsv.gz" - assert resolved_output_file.name == output_name - elif suffix != "": - assert resolved_output_file is not None - assert resolved_output_file.parent == tmp_path / "outdir" - assert resolved_output_file.suffix == ".gz" - elif output_file.is_file(): - assert resolved_output_file == output_file + resolved_output = (Path(shard.input_file).parent / input_stem).with_suffix( + out_suffix + ) + elif output_file.is_dir(): + 
resolved_output = (output_file / input_stem).with_suffix(out_suffix) + elif len(output_file.suffixes) == 0: + resolved_output = output_file.with_suffix(out_suffix) else: - assert resolved_output_file is not None - assert resolved_output_file.parent == tmp_path / "outdir" - assert resolved_output_file.name == "tmp.tsv" - - shard2 = Shard(tmp_file, index=5) - resolved_output_shard = sh.resolve_output(shard2, output_file, suffix) - assert resolved_output_shard is not None - if suffix == ".tsv.gz": - assert resolved_output_shard.suffixes == [".5", ".tsv", ".gz"] - elif output_file is None or output_file.name == "outdir": - assert resolved_output_shard.suffixes == [".5", ".tsv"] - elif output_file.name == "outfile": - assert resolved_output_shard.suffixes == [".5"] + # file.ext1.ext2.ext3 --> file + output_stem = output_file.name[: -len("".join(output_file.suffixes))] + resolved_output = output_file.parent / (output_stem + out_suffix) + + # Happens when suffix = "" and output_file = None and input_file is not compressed + if resolved_output.resolve() == Path(shard.input_file).resolve(): + log.warning( + f"Output file is the same as input file ({shard.input_file}). Writing is disabled" + ) + return None + return resolved_output.resolve() + + +def parse_header( + input_file: Path, + header: tp.Optional[tp.Union[bool, tp.List[str]]], + sep: tp.Optional[str], +) -> tp.Optional[tp.List[str]]: + + if header is False or header is None: + return None + + assert ( + sep is not None + ), "Please provide separator input. Sharder will not guess the file format cannot be guessed at this point" + + with stopes_open(input_file) as reader: + parsed_cols = next(reader).rstrip("\n").split(sep) + + if header is True: + return parsed_cols + if isinstance(header, (list, tuple)): + return list(header) + + +def make_one_text_shard( + input_file: Path, + header: tp.Optional[tp.Union[bool, tp.List[str]]] = False, + sep: tp.Optional[str] = None, + nrows: tp.Optional[int] = None, + filter: tp.Optional[pa.dataset.Expression] = None, +) -> tp.Sequence[TextShard]: + cols = parse_header(input_file, header, sep) + if nrows is not None: + return [ + TopNShard( + input_file=input_file, + columns=cols, + index=0, + sep=sep, + nrows=nrows, + filter=filter, + ) + ] else: - assert resolved_output_shard.suffixes == [".5", ".txt"] - - -@pytest.mark.parametrize("header", [False, True]) -def test_shard_values(header, tmp_path: Path): - - tmp_file = tmp_path / "tmp.tsv" - with open(tmp_file, "w") as o: - if header: - o.write("header1\theader2\n") - o.write("val1\tval2\n") - o.write("val3\tval4\n") - - output_dir = tmp_path / "output" - output_dir.mkdir() - - cols = ["header1", "header2"] if header else None - col_name = "header2" if header else 1 - - shard = Shard(tmp_file, cols=cols) - out_file = sh.resolve_output(shard, output_dir, suffix=".txt") - assert out_file is not None - assert out_file.resolve() == output_dir / "tmp.txt" - - with open(out_file, "w") as o, shard as progress: - for line in progress: - val = progress.value(col_name) # type: ignore - o.write(val) - - vals = open(out_file).readline() - assert vals == "val2val4" - - -def test_shard_headers(): - for i in range(5): - shard = Shard(Path("input.tsv"), cols=["header"], index=i) - - # The first shard should contains the header if the `cols` is passed - if i == 0: - assert shard.contains_header() - else: - assert not shard.contains_header() - - -def test_make_one_shard(tmp_path: Path): - with open(tmp_path / "tmp.tsv", "w") as o: - o.write("header1;heade2;header3\n") - 
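parse_header above supports three header conventions: no header (False/None), header read from the first line (True), and caller-supplied column names (a list). A standalone sketch of just that decision table, applied to an already-read first line rather than a file; the helper name resolve_columns is illustrative only.

import typing as tp


def resolve_columns(
    first_line: str, header: tp.Union[bool, tp.List[str], None], sep: str
) -> tp.Optional[tp.List[str]]:
    if header is False or header is None:
        return None  # the file has no header row
    if header is True:
        return first_line.rstrip("\n").split(sep)  # take names from the first line
    return list(header)  # trust the caller-provided names


assert resolve_columns("cat\tname\tscore\n", True, "\t") == ["cat", "name", "score"]
assert resolve_columns("cat\tname\tscore\n", ["a", "b", "c"], "\t") == ["a", "b", "c"]
assert resolve_columns("cat\tname\tscore\n", False, "\t") is None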
o.write("1;2;3\n") - - shards = sh.make_one_shard(tmp_path / "tmp.tsv", header=True, sep=";") - assert len(shards) == 1 - with shards[0] as progress: - next(iter(progress)) - col = progress.value("header3") - assert int(col) == 3 - - -def test_one_shard_with_nrows(tmp_path: Path): - with open(tmp_path / "tmp.tsv", "w") as o: - o.write("header1|heade2|header3\n") - [o.write("col1|col2|col3\n") for _ in range(10)] - - shards = sh.make_one_shard(tmp_path / "tmp.tsv", header=True, sep="|", nrows=5) - with shards[0] as shard: - col = shard.resolve_column_index("header3") - assert col == 2 - lines = list(iter(shard)) - assert len(lines) == 5 - assert lines[0].rstrip().split("|")[col] == "col3" - - -def test_make_shards_from_file_list(tmp_path: Path): - input_files = [] - for i in range(5): - input_file = tmp_path / f"file{i}.tsv" - input_file.touch() - input_files.append(input_file) - - # `nshards`` should have no impact here - shards = sh.make_shards(input_files, nshards=50) - assert len(shards) == len(input_files) - for i, p in enumerate(shards): - assert Path(p.input_file).name == f"file{i}.tsv" + return [ + TextShard( + input_file=input_file, columns=cols, index=0, sep=sep, filter=filter + ) + ] -@pytest.mark.parametrize("zip", [False, True]) -@pytest.mark.parametrize("header", [False, True]) -def test_make_chunk_shards_from_single_file(tmp_path: Path, zip: bool, header: bool): - if zip: - tmp_file = tmp_path / "tmp.tsv.gz" - else: - tmp_file = tmp_path / "tmp.tsv" - with open_write(tmp_file) as o: - if header: - o.write("header1|header2\n") - [o.write(f"line{i}|val{i}\n") for i in range(23)] - out_dir = tmp_path / "fakedir" - out_dir.mkdir() - shards = sh.make_shards( - tmp_file, nshards=5, header=header, sep="|", cache_dir=out_dir - ) +def make_chunk_text_shards( + input_file: Path, + nshards: int = 1, + header: tp.Optional[tp.Union[bool, tp.List[str]]] = False, + sep: tp.Optional[str] = None, + cache_dir: tp.Optional[Path] = None, + nrows: tp.Optional[int] = None, + filter: tp.Optional[pa.dataset.Expression] = None, +) -> tp.Sequence[ChunkShard]: + """ + Create multiple shards from a single file, where each share corresponds to a continuous + chunk of lines in the file. If the file is compressed, it will be decompressed into + the `cache_dir` under the new name: `input_file_expanded.txt` + """ + + from stopes.utils.file_chunker_utils import find_offsets + + if input_file.suffix in {".gz", ".xz"}: + warn_once( + "Multiple shards for compressed file is asked. Chunking the compressed file results " + "in a slow scheduler warm-up. Please give the decompressed file if possible." 
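make_chunk_text_shards relies on find_offsets from stopes.utils.file_chunker_utils to split one file into contiguous byte ranges. A rough standalone approximation of that idea is shown below, under the assumption that split points are advanced to the next newline so no chunk cuts a line in half; this is only a sketch, not the stopes implementation.

from pathlib import Path
from typing import List


def line_aligned_offsets(path: Path, nshards: int) -> List[int]:
    # Returns nshards + 1 byte offsets; consecutive pairs delimit one chunk each.
    size = path.stat().st_size
    offsets = [0]
    with path.open("rb") as f:
        for i in range(1, nshards):
            f.seek(i * size // nshards)  # jump to an approximate split point
            f.readline()  # then move forward to the next line boundary
            offsets.append(f.tell())
    offsets.append(size)
    return offsets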
+ ) + _cache = cache_dir if cache_dir else input_file.parent + assert Path( + _cache + ).exists(), ( + f"cache directory {_cache} not found, cannot write intermediate files" + ) + input_file = expand_if_compressed(input_file, _cache) # type: ignore - # Change filename to reflect fairseq internal decompression - if zip: - tmp_file = out_dir / (tmp_file.name + "_expanded.txt") - - if header: - cols = ["header1", "header2"] - col_name = "header2" - expected_shards = [ - ChunkShard(tmp_file, cols=cols, sep="|", index=0, start=0, end=60), - ChunkShard(tmp_file, cols=cols, sep="|", index=1, start=60, end=126), - ChunkShard(tmp_file, cols=cols, sep="|", index=2, start=126, end=178), - ChunkShard(tmp_file, cols=cols, sep="|", index=3, start=178, end=243), - ChunkShard(tmp_file, cols=cols, sep="|", index=4, start=243, end=295), - ] - else: - col_name = 1 # type: ignore - expected_shards = [ - ChunkShard(tmp_file, sep="|", index=0, start=0, end=66), - ChunkShard(tmp_file, sep="|", index=1, start=66, end=123), - ChunkShard(tmp_file, sep="|", index=2, start=123, end=175), - ChunkShard(tmp_file, sep="|", index=3, start=175, end=227), - ChunkShard(tmp_file, sep="|", index=4, start=227, end=279), - ] - assert len(shards) == len(expected_shards) - - line_idx = 0 - for i in range(5): - assert shards[i] == expected_shards[i] - with shards[i] as progress: - for _ in progress: - assert progress.value(col_name) == f"val{line_idx}" - line_idx += 1 - - -@pytest.mark.parametrize("zip", [False, True]) -@pytest.mark.parametrize("header", [False, True]) -def test_make_chunk_shards_with_nrows(tmp_path: Path, zip: bool, header: bool): - if zip: - tmp_file = tmp_path / "tmp.tsv.gz" + if nrows: + offsets = find_offsets_of_lines(str(input_file), nshards, nrows) else: - tmp_file = tmp_path / "tmp.tsv" - with open_write(tmp_file) as o: - if header: - o.write("header1|header2\n") - [o.write(f"line{i}|{i}\n") for i in range(50)] - # Change filename to reflect fairseq internal decompression - shards = sh.make_shards(tmp_file, nshards=5, header=header, sep="|", nrows=23) - - if zip: - tmp_filename = tmp_file.name + "_expanded.txt" - else: - tmp_filename = tmp_file.name - real_input = tmp_file.parent / tmp_filename - if header: - cols = ["header1", "header2"] - col_name = "header2" - expected_shards = [ - ChunkShard(real_input, cols=cols, sep="|", index=0, start=0, end=48), - ChunkShard(real_input, cols=cols, sep="|", index=1, start=48, end=88), - ChunkShard(real_input, cols=cols, sep="|", index=2, start=88, end=136), - ChunkShard(real_input, cols=cols, sep="|", index=3, start=136, end=176), - ChunkShard(real_input, cols=cols, sep="|", index=4, start=176, end=216), - ] - else: - col_name = 1 # type: ignore[assignment] - expected_shards = [ - ChunkShard(real_input, sep="|", index=0, start=0, end=40), - ChunkShard(real_input, sep="|", index=1, start=40, end=80), - ChunkShard(real_input, sep="|", index=2, start=80, end=130), - ChunkShard(real_input, sep="|", index=3, start=130, end=170), - ChunkShard(real_input, sep="|", index=4, start=170, end=210), - ] - - assert len(shards) == len(expected_shards) - line_idx = 0 - for i in range(5): - assert shards[i] == expected_shards[i] - with shards[i] as shard: - for _ in iter(shard): - assert int(shard.value(col_name)) == line_idx - line_idx += 1 - assert line_idx == (22 if header else 23) - - -@pytest.mark.parametrize("header", [False, True]) -def test_make_robin_shards(tmp_path: Path, header: bool): - nshards = 5 - tmp_file = tmp_path / "tmp.tsv.gz" - with open_write(tmp_file) as o: - 
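When a compressed input has to be chunked, the code above first materializes a decompressed copy in cache_dir (or next to the input). A standalone sketch of that expansion step, assuming gzip input and the "_expanded.txt" naming convention used elsewhere in this diff; the helper expand_gz is illustrative, not the stopes expand_if_compressed.

import gzip
import shutil
from pathlib import Path


def expand_gz(input_file: Path, cache_dir: Path) -> Path:
    # Stream-decompress so large corpora never have to fit in memory.
    expanded = cache_dir / (input_file.name + "_expanded.txt")
    with gzip.open(input_file, "rb") as src, expanded.open("wb") as dst:
        shutil.copyfileobj(src, dst)
    return expanded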
if header: - o.write("header1\theader2\n") - [o.write(f"line\t{i}\n") for i in range(23)] - sep = "\t" if header else None - shards = sh.make_shards( - tmp_file, nshards=nshards, algo="robin", header=header, sep=sep - ) - if header: - cols = ["header1", "header2"] - col_name = "header2" - expected_shards = [ - RoundRobinShard(tmp_file, cols=cols, sep=sep, index=i, nshards=nshards) - for i in range(nshards) - ] - expected_vals = [ - [4, 9, 14, 19], - [0, 5, 10, 15, 20], - [1, 6, 11, 16, 21], - [2, 7, 12, 17, 22], - [3, 8, 13, 18], + offsets = find_offsets(str(input_file), nshards) + + # Convert [pos1, pos2, pos3,...] to [(pos1, pos2), (pos2, pos3),..] + file_chunks = zip(offsets, offsets[1:]) + + return [ + ChunkShard( + input_file=input_file, + columns=parse_header(input_file, header, sep), + index=i, + sep=sep, + start=start, + end=end, + filter=filter, + ) + for i, (start, end) in enumerate(file_chunks) + ] + + +def make_roundrobin_text_shards( + input_file: Path, + nshards: int = 1, + header: tp.Optional[tp.Union[bool, tp.List[str]]] = False, + sep: tp.Optional[str] = None, + filter: tp.Optional[pa.dataset.Expression] = None, +) -> tp.Sequence[RoundRobinShard]: + """ + Make multiple shards from a single file where each shard correspond to all lines in the file + with certain index. For example, if there are 8 shards, shard_0 corresponds to lines 0, 8, 16,.. + """ + return [ + RoundRobinShard( + input_file=input_file, + columns=parse_header(input_file, header, sep), + index=i, + sep=sep, + nshards=nshards, + filter=filter, + ) + for i in range(nshards) + ] + + +def make_sorted_text_shards( + input_file: Path, + nshards: int = 1, + header: tp.Optional[tp.Union[bool, tp.List[str]]] = False, + sep: tp.Optional[str] = None, + cache_dir: tp.Optional[Path] = None, + col: tp.Optional[tp.Union[str, int]] = None, + no_duplicate: bool = False, + filter: tp.Optional[pa.dataset.Expression] = None, +) -> tp.Sequence[TextShard]: + """ + Create shards from one file by sorting the lines after values in column `col` and divide the + sorted lines into different chunks. This algorithm requires input file to be uncompressed + before. If `no_duplicate` is True, the shard will not have duplicates + """ + + from stopes.utils.file_chunker_utils import find_offsets + + if input_file.suffix in {".gz", ".xz"}: + warn_once( + "Multiple shards for compressed file is asked. Chunking the compressed file results " + "in a slow scheduler warm-up. Please give the decompressed file if possible." + ) + _cache = cache_dir if cache_dir else input_file.parent + assert Path( + _cache + ).exists(), ( + f"cache directory {_cache} not found, cannot write intermediate files" + ) + input_file = expand_if_compressed(input_file, _cache) # type: ignore + + sorted_file = str(input_file) + ".merge_sort" + sort_file(input_file, sorted_file, col=col, sep=sep, no_duplicate=no_duplicate) + offsets = find_offsets(str(sorted_file), nshards) + + # Convert [pos1, pos2, pos3,...] to [(pos1, pos2), (pos2, pos3),..] 
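# A tiny standalone illustration of the conversion described in the comment above
# (N + 1 offsets become N contiguous (start, end) ranges via a pairwise zip);
# the variable names below are illustrative only:
offsets_example = [0, 40, 80, 130]
ranges = list(zip(offsets_example, offsets_example[1:]))
assert ranges == [(0, 40), (40, 80), (80, 130)]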
+ file_chunks = zip(offsets, offsets[1:]) + return [ + ChunkShard( + input_file=sorted_file, + sep=sep, + columns=parse_header(input_file, header, sep), + index=i, + start=start, + end=end, + filter=filter, + ) + for i, (start, end) in enumerate(file_chunks) + ] + + +def make_text_file_shards( + input: tp.Union[str, tp.Sequence, Path], + nshards: int = 1, + algo: str = "chunk", + header: tp.Optional[tp.Union[bool, tp.List[str]]] = False, + sep: tp.Optional[str] = None, + cache_dir: tp.Optional[Path] = None, + filter: tp.Optional[pa.dataset.Expression] = None, + **kwargs, +) -> tp.Sequence[TextShard]: + """ + Make shards from an `input`. + + Args: + input (str, list or Path): Input to gnerate the shards. It could be: + nshards: The number of shards to generate from the input. Only applicable if `input` + is a single file. + header (bool): Whether or not the input files have headers. Default False + sep: separator for columns in the line (all files in `input` must have the same format and separator) + cache_dir (str, Path, optional): directory to cache the intermedia files (such as uncompressed input file) + filter : to apply when batching # TODO: apply them for simple row iterations for consistency + + """ + assert nshards > 0, f"invalid number of shards ({nshards})" + if isinstance(input, tp.List) or isinstance(input, ListConfig): + return [ + s + for f in input + for s in make_text_file_shards( + input=f, + nshards=1, + algo=algo, + header=header, + sep=sep, + cache_dir=cache_dir, + filter=filter, + **kwargs, + ) ] - else: - col_name = 1 # type: ignore - expected_shards = [ - RoundRobinShard(tmp_file, index=i, nshards=nshards) for i in range(nshards) + elif (p := Path(input)).is_dir(): # type: ignore + return [ + s + for f in p.iterdir() + for s in make_text_file_shards( + input=f, + nshards=1, + algo=algo, + header=header, + sep=sep, + cache_dir=cache_dir, + filter=filter, + **kwargs, + ) ] - expected_vals = [ - [0, 5, 10, 15, 20], - [1, 6, 11, 16, 21], - [2, 7, 12, 17, 22], - [3, 8, 13, 18], - [4, 9, 14, 19], + elif not p.is_file(): + return [ + s + for f in sorted(glob(str(input))) + for s in make_text_file_shards( + input=f, + nshards=1, + algo=algo, + header=header, + sep=sep, + cache_dir=cache_dir, + filter=filter, + **kwargs, + ) ] - assert shards == expected_shards - if header: - assert shards[0].contains_header() - assert all([not x.contains_header() for x in shards[1:]]) - - for i in range(5): - with shards[i] as progress: - vals = [int(progress.value(col_name)) for _ in progress] - assert list(vals) == expected_vals[i] - - -@pytest.mark.parametrize("header", [False, True]) -def test_make_sorted_shards(tmp_path: Path, header: bool): - import string - - tmp_file = tmp_path / "tmp.tsv" - with open_write(tmp_file) as o: - if header: - o.write("header1\theader2\n") - chs = list(string.ascii_lowercase) - random.shuffle(chs) - [o.write(f"a really long line\t{i}\n") for i in chs] - - col = "header2" if header else 1 - shards = sh.make_shards( - tmp_file, nshards=5, algo="sort", header=header, sep="\t", col=col - ) - sorted_file = str(tmp_file) + ".merge_sort" - assert Path(sorted_file).exists() - with open(sorted_file) as f: - lines = iter(f) - if header: - lines = itertools.islice(iter(f), 1, None) - for line, ch in zip(lines, string.ascii_lowercase): - assert line == f"a really long line\t{ch}\n" - assert len(shards) == 5 - if header: - assert shards[0] == ChunkShard( - sorted_file, - sep="\t", - cols=["header1", "header2"], - index=0, - start=0, - end=121, + elif nshards == 1: + 
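Taken together, make_text_file_shards accepts a single file, a directory, a glob pattern, or an explicit list, and recurses until it reaches individual files before dispatching on algo. A hypothetical usage sketch follows; the corpus content is made up, and it assumes the stopes.utils.sharding.text_shards module introduced in this diff is importable.

import tempfile
from pathlib import Path

from stopes.utils.sharding.text_shards import make_text_file_shards

with tempfile.TemporaryDirectory() as tmp:
    corpus = Path(tmp) / "corpus.tsv"
    corpus.write_text("col1\tcol2\n" + "".join(f"a{i}\tb{i}\n" for i in range(100)))
    shards = make_text_file_shards(corpus, nshards=4, algo="chunk", header=True, sep="\t")
    for shard in shards:
        with shard as lines:
            for line in lines:
                pass  # each shard iterates one contiguous chunk of corpus.tsv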
return make_one_text_shard( + p, + header, + sep, + filter=filter, + ) + elif algo == "chunk": + return make_chunk_text_shards( + p, + nshards, + header, + sep, + cache_dir, + kwargs.get("nrows"), + filter=filter, + ) + elif algo == "robin": + return make_roundrobin_text_shards( + p, + nshards, + header, + sep, + filter=filter, + ) + elif algo == "sort": + return make_sorted_text_shards( + input_file=p, + nshards=nshards, + header=header, + sep=sep, + cache_dir=cache_dir, + col=kwargs.get("col"), + no_duplicate=bool(kwargs.get("no_duplicate")), + filter=filter, ) - else: - assert shards[0] == ChunkShard(sorted_file, sep="\t", index=0, start=0, end=126) + raise ValueError( + f"invalid input: input={str(input)}, nshards={nshards}, algo={algo}" + ) -def test_make_shards_from_glob(tmp_path: Path): - (tmp_path / "file1.tsv").touch() - (tmp_path / "file2.tsv").touch() - shards = list(sh.make_shards(tmp_path / "file*.tsv", cache_dir=Path("fakedir"))) - assert len(shards) == 2 - shards.sort(key=lambda x: Path(x.input_file).name) - assert Path(shards[0].input_file).name == "file1.tsv" - assert Path(shards[1].input_file).name == "file2.tsv" +def merge_shards(shards: tp.List[tp.Union[Path]], outdir: Path, suffix: str = ""): + """Merge the shard outputs in the order of the shard indices""" + def get_name_no_ext(fname: str) -> str: + if len(suffix) == 0: + return fname + return fname[: -len(suffix)] -@contextlib.contextmanager -def assert_warns(caplog, *, match: str) -> Iterator[None]: - caplog.clear() - sh.warn_once.cache_clear() + def get_shard_idx(fname: Path) -> int: + fname_no_ext = get_name_no_ext(str(fname)) + shard_idx = fname_no_ext[fname_no_ext.rfind(".") + 1 :] + return int(shard_idx) - with caplog.at_level(logging.WARN): - yield - assert len(caplog.messages) == 1 - assert re.match(match, caplog.messages[0]) - caplog.clear() + if len(shards) == 1: + shutil.copy2(shards[0], outdir) + else: + fname = get_name_no_ext(shards[0].name) + outputfile = str(outdir) + "/" + (fname[: fname.rfind(".")] + suffix) + ordered_shards = sorted(shards, key=get_shard_idx) + log.info(f"Writing {len(ordered_shards)} shard outputs to {outputfile}") + + with stopes_open(outputfile, "wt") as o: + for shard_output in ordered_shards: + try: + with stopes_open(shard_output) as so: + for line in so: + o.write(line) + except Exception as err: + log.error(f"Error in processing {shard_output}") diff --git a/stopes/utils/tests/test_text_shards.py b/stopes/utils/tests/test_text_shards.py new file mode 100644 index 0000000..498f8c3 --- /dev/null +++ b/stopes/utils/tests/test_text_shards.py @@ -0,0 +1,644 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
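merge_shards above reassembles per-shard outputs in shard order by parsing the numeric index that resolve_output places just before the suffix. A standalone sketch of that ordering step, assuming output names like out.0.tsv, out.1.tsv, ...; the helper order_by_shard_index is illustrative and not part of stopes.

from pathlib import Path
from typing import List


def order_by_shard_index(paths: List[Path], suffix: str = ".tsv") -> List[Path]:
    def shard_idx(p: Path) -> int:
        stem = p.name[: -len(suffix)] if suffix else p.name
        return int(stem[stem.rfind(".") + 1 :])  # "out.10.tsv" -> 10

    return sorted(paths, key=shard_idx)


files = [Path("out.10.tsv"), Path("out.2.tsv"), Path("out.1.tsv")]
# Numeric ordering, unlike a plain lexicographic sort which would put "out.10" before "out.2".
assert [p.name for p in order_by_shard_index(files)] == ["out.1.tsv", "out.2.tsv", "out.10.tsv"]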
+ +import contextlib +import itertools +import logging +import random +import re +from pathlib import Path +from typing import Iterator, List + +import pandas as pd +import pytest + +from stopes.core.utils import open_write +from stopes.utils.sharding import text_shards as sh +from stopes.utils.sharding.abstract_shards import BatchFormat, concat_batches +from stopes.utils.sharding.text_shards import ( + ChunkShard, + RoundRobinShard, + TextOutputConfig, + TextShard, + TextShardingConfig, +) +from stopes.utils.tests.test_parquet_shards import ( + get_random_table, + permutationally_equal_dataframes, +) + + +def test_shard_context(tmp_path: Path): + tmp_file = tmp_path / "tmp.tsv.gz" + + # Test that when a shard enters, the underlying input file must exist + with pytest.raises(FileNotFoundError): + with TextShard(input_file=tmp_file, filter=None) as _: + pass + + +def test_shard_internal_update(tmp_path: Path): + tmp_file = tmp_path / "tmp.tsv.gz" + # Test that the shard.lines is properly updated during shard iteration + line_cnt = 10 + with open_write(tmp_file, "w") as o: + [o.write(f"line{i}\n") for i in range(line_cnt)] + + shard = TextShard(input_file=tmp_file, filter=None) + break_point = 4 + with shard as first_pass: + for i, line in enumerate(iter(first_pass)): + assert line.rstrip() == f"line{i}" + if i == break_point: + break + + assert shard.has_started() + with shard as second_pass: + for i, line in enumerate(iter(second_pass)): + assert line.rstrip() == f"line{i + break_point + 1}" + assert shard._lines_cnt == line_cnt + + +@pytest.mark.parametrize( + "input_file", ["tmp.tsv", "tmp.tsv.gz", "tmp.tsv_expanded.txt"] +) +@pytest.mark.parametrize("output_file", [None, "output.txt", "outdir", "outfile"]) +@pytest.mark.parametrize("suffix", ["", ".tsv.gz"]) +def test_resolve_output(tmp_path: Path, input_file, output_file, suffix): + tmp_file = tmp_path / input_file + if output_file: + (tmp_path / "outdir").mkdir() + if output_file == "outdir": + output_file = tmp_path / "outdir" + else: + output_file = tmp_path / "outdir" / output_file + output_file.touch() + + shard1 = TextShard(input_file=tmp_file, filter=None) + resolved_output_file = sh.resolve_output(shard1, output_file, suffix) + + if output_file is None: + if suffix == "" and input_file == "tmp.tsv": + assert resolved_output_file is None + elif suffix == ".tsv.gz" and input_file == "tmp.tsv.gz": + assert resolved_output_file is None + else: + assert resolved_output_file is not None + assert resolved_output_file.parent == tmp_path + output_name = "tmp.tsv" if suffix == "" else "tmp.tsv.gz" + assert resolved_output_file.name == output_name + elif suffix != "": + assert resolved_output_file is not None + assert resolved_output_file.parent == tmp_path / "outdir" + assert resolved_output_file.suffix == ".gz" + elif output_file.is_file(): + assert resolved_output_file == output_file + else: + assert resolved_output_file is not None + assert resolved_output_file.parent == tmp_path / "outdir" + assert resolved_output_file.name == "tmp.tsv" + + shard2 = TextShard(input_file=tmp_file, index=5, filter=None) + resolved_output_shard = sh.resolve_output(shard2, output_file, suffix) + assert resolved_output_shard is not None + if suffix == ".tsv.gz": + assert resolved_output_shard.suffixes == [".5", ".tsv", ".gz"] + elif output_file is None or output_file.name == "outdir": + assert resolved_output_shard.suffixes == [".5", ".tsv"] + elif output_file.name == "outfile": + assert resolved_output_shard.suffixes == [".5"] + else: + assert 
resolved_output_shard.suffixes == [".5", ".txt"] + + +@pytest.mark.parametrize("header", [False, True]) +def test_shard_values(header, tmp_path: Path): + tmp_file = tmp_path / "tmp.tsv" + with open(tmp_file, "w") as o: + if header: + o.write("header1\theader2\n") + o.write("val1\tval2\n") + o.write("val3\tval4\n") + + output_dir = tmp_path / "output" + output_dir.mkdir() + + cols = ["header1", "header2"] if header else None + col_name = "header2" if header else 1 + + shard = TextShard(input_file=tmp_file, columns=cols, filter=None) + out_file = sh.resolve_output(shard, output_dir, suffix=".txt") + assert out_file is not None + assert out_file.resolve() == output_dir / "tmp.txt" + + with open(out_file, "w") as o, shard as progress: + for line in progress: + val = progress.value(col_name) # type: ignore + o.write(val) + + vals = open(out_file).readline() + assert vals == "val2val4" + + +def test_shard_headers(): + for i in range(5): + shard = TextShard( + input_file=Path("input.tsv"), columns=["header"], index=i, filter=None + ) + + # The first shard should contains the header if the `cols` is passed + if i == 0: + assert shard.contains_header() + else: + assert not shard.contains_header() + + +def test_make_one_shard(tmp_path: Path): + with open(tmp_path / "tmp.tsv", "w") as o: + o.write("header1;heade2;header3\n") + o.write("1;2;3\n") + + shards = sh.make_one_text_shard(tmp_path / "tmp.tsv", header=True, sep=";") + assert len(shards) == 1 + with shards[0] as progress: + next(iter(progress)) + col = progress.value("header3") + assert int(col) == 3 + + +def test_one_shard_with_nrows(tmp_path: Path): + with open(tmp_path / "tmp.tsv", "w") as o: + o.write("header1|heade2|header3\n") + [o.write("col1|col2|col3\n") for _ in range(10)] + + shards = sh.make_one_text_shard(tmp_path / "tmp.tsv", header=True, sep="|", nrows=5) + with shards[0] as shard: + col = shard.resolve_column_index("header3") + assert col == 2 + lines = list(iter(shard)) + assert len(lines) == 5 + assert lines[0].rstrip().split("|")[col] == "col3" + + +def test_make_shards_from_file_list(tmp_path: Path): + input_files = [] + for i in range(5): + input_file = tmp_path / f"file{i}.tsv" + input_file.touch() + input_files.append(input_file) + + # `nshards`` should have no impact here + shards = sh.make_text_file_shards(input_files, nshards=50) + assert len(shards) == len(input_files) + for i, p in enumerate(shards): + assert Path(p.input_file).name == f"file{i}.tsv" + + +@pytest.mark.parametrize("zip", [False, True]) +@pytest.mark.parametrize("header", [False, True]) +def test_make_chunk_shards_from_single_file(tmp_path: Path, zip: bool, header: bool): + if zip: + tmp_file = tmp_path / "tmp.tsv.gz" + else: + tmp_file = tmp_path / "tmp.tsv" + with open_write(tmp_file) as o: + if header: + o.write("header1|header2\n") + [o.write(f"line{i}|val{i}\n") for i in range(23)] + out_dir = tmp_path / "fakedir" + out_dir.mkdir() + shards = sh.make_text_file_shards( + tmp_file, nshards=5, header=header, sep="|", cache_dir=out_dir + ) + + # Change filename to reflect fairseq internal decompression + if zip: + tmp_file = out_dir / (tmp_file.name + "_expanded.txt") + + if header: + cols = ["header1", "header2"] + col_name = "header2" + expected_shards = [ + ChunkShard( + input_file=tmp_file, + columns=cols, + sep="|", + index=0, + start=0, + end=60, + filter=None, + ), + ChunkShard( + input_file=tmp_file, + columns=cols, + sep="|", + index=1, + start=60, + end=126, + filter=None, + ), + ChunkShard( + input_file=tmp_file, + columns=cols, + sep="|", 
+ index=2, + start=126, + end=178, + filter=None, + ), + ChunkShard( + input_file=tmp_file, + columns=cols, + sep="|", + index=3, + start=178, + end=243, + filter=None, + ), + ChunkShard( + input_file=tmp_file, + columns=cols, + sep="|", + index=4, + start=243, + end=295, + filter=None, + ), + ] + else: + col_name = 1 # type: ignore + expected_shards = [ + ChunkShard( + input_file=tmp_file, sep="|", index=0, start=0, end=66, filter=None + ), + ChunkShard( + input_file=tmp_file, sep="|", index=1, start=66, end=123, filter=None + ), + ChunkShard( + input_file=tmp_file, sep="|", index=2, start=123, end=175, filter=None + ), + ChunkShard( + input_file=tmp_file, sep="|", index=3, start=175, end=227, filter=None + ), + ChunkShard( + input_file=tmp_file, sep="|", index=4, start=227, end=279, filter=None + ), + ] + assert len(shards) == len(expected_shards) + + line_idx = 0 + for i in range(5): + assert shards[i] == expected_shards[i] + with shards[i] as progress: + for _ in progress: + assert progress.value(col_name) == f"val{line_idx}" + line_idx += 1 + + +@pytest.mark.parametrize("zip", [False, True]) +@pytest.mark.parametrize("header", [False, True]) +def test_make_chunk_shards_with_nrows(tmp_path: Path, zip: bool, header: bool): + if zip: + tmp_file = tmp_path / "tmp.tsv.gz" + else: + tmp_file = tmp_path / "tmp.tsv" + with open_write(tmp_file) as o: + if header: + o.write("header1|header2\n") + [o.write(f"line{i}|{i}\n") for i in range(50)] + # Change filename to reflect fairseq internal decompression + shards = sh.make_text_file_shards( + tmp_file, nshards=5, header=header, sep="|", nrows=23 + ) + + if zip: + tmp_filename = tmp_file.name + "_expanded.txt" + else: + tmp_filename = tmp_file.name + real_input = tmp_file.parent / tmp_filename + if header: + cols = ["header1", "header2"] + col_name = "header2" + expected_shards = [ + ChunkShard( + input_file=real_input, + columns=cols, + sep="|", + index=0, + start=0, + end=48, + filter=None, + ), + ChunkShard( + input_file=real_input, + columns=cols, + sep="|", + index=1, + start=48, + end=88, + filter=None, + ), + ChunkShard( + input_file=real_input, + columns=cols, + sep="|", + index=2, + start=88, + end=136, + filter=None, + ), + ChunkShard( + input_file=real_input, + columns=cols, + sep="|", + index=3, + start=136, + end=176, + filter=None, + ), + ChunkShard( + input_file=real_input, + columns=cols, + sep="|", + index=4, + start=176, + end=216, + filter=None, + ), + ] + else: + col_name = 1 # type: ignore[assignment] + expected_shards = [ + ChunkShard( + input_file=real_input, sep="|", index=0, start=0, end=40, filter=None + ), + ChunkShard( + input_file=real_input, sep="|", index=1, start=40, end=80, filter=None + ), + ChunkShard( + input_file=real_input, sep="|", index=2, start=80, end=130, filter=None + ), + ChunkShard( + input_file=real_input, sep="|", index=3, start=130, end=170, filter=None + ), + ChunkShard( + input_file=real_input, sep="|", index=4, start=170, end=210, filter=None + ), + ] + + assert len(shards) == len(expected_shards) + line_idx = 0 + for i in range(5): + assert shards[i] == expected_shards[i] + with shards[i] as shard: + for _ in iter(shard): + assert int(shard.value(col_name)) == line_idx + line_idx += 1 + assert line_idx == (22 if header else 23) + + +@pytest.mark.parametrize("header", [False, True]) +def test_make_robin_shards(tmp_path: Path, header: bool): + nshards = 5 + tmp_file = tmp_path / "tmp.tsv.gz" + with open_write(tmp_file) as o: + if header: + o.write("header1\theader2\n") + [o.write(f"line\t{i}\n") 
for i in range(23)] + sep = "\t" if header else None + shards = sh.make_text_file_shards( + tmp_file, nshards=nshards, algo="robin", header=header, sep=sep + ) + if header: + cols = ["header1", "header2"] + col_name = "header2" + expected_shards = [ + RoundRobinShard( + input_file=tmp_file, + columns=cols, + sep=sep, + index=i, + nshards=nshards, + filter=None, + ) + for i in range(nshards) + ] + expected_vals = [ + [4, 9, 14, 19], + [0, 5, 10, 15, 20], + [1, 6, 11, 16, 21], + [2, 7, 12, 17, 22], + [3, 8, 13, 18], + ] + else: + col_name = 1 # type: ignore + expected_shards = [ + RoundRobinShard(input_file=tmp_file, index=i, nshards=nshards, filter=None) + for i in range(nshards) + ] + expected_vals = [ + [0, 5, 10, 15, 20], + [1, 6, 11, 16, 21], + [2, 7, 12, 17, 22], + [3, 8, 13, 18], + [4, 9, 14, 19], + ] + assert shards == expected_shards + if header: + assert shards[0].contains_header() + assert all([not x.contains_header() for x in shards[1:]]) + + for i in range(5): + with shards[i] as progress: + vals = [int(progress.value(col_name)) for _ in progress] + assert list(vals) == expected_vals[i] + + +@pytest.mark.parametrize("header", [False, True]) +def test_make_sorted_shards(tmp_path: Path, header: bool): + import string + + tmp_file = tmp_path / "tmp.tsv" + with open_write(tmp_file) as o: + if header: + o.write("header1\theader2\n") + chs = list(string.ascii_lowercase) + random.shuffle(chs) + [o.write(f"a really long line\t{i}\n") for i in chs] + + col = "header2" if header else 1 + shards = sh.make_text_file_shards( + tmp_file, nshards=5, algo="sort", header=header, sep="\t", col=col, filter=None + ) + sorted_file = str(tmp_file) + ".merge_sort" + assert Path(sorted_file).exists() + with open(sorted_file) as f: + lines = iter(f) + if header: + lines = itertools.islice(iter(f), 1, None) + for line, ch in zip(lines, string.ascii_lowercase): + assert line == f"a really long line\t{ch}\n" + assert len(shards) == 5 + if header: + assert shards[0] == ChunkShard( + input_file=sorted_file, + sep="\t", + columns=["header1", "header2"], + index=0, + start=0, + end=121, + filter=None, + ) + else: + assert shards[0] == ChunkShard( + input_file=sorted_file, sep="\t", index=0, start=0, end=126, filter=None + ) + + +def test_make_shards_from_glob(tmp_path: Path): + (tmp_path / "file1.tsv").touch() + (tmp_path / "file2.tsv").touch() + + shards = list( + sh.make_text_file_shards(tmp_path / "file*.tsv", cache_dir=Path("fakedir")) + ) + assert len(shards) == 2 + shards.sort(key=lambda x: Path(x.input_file).name) + assert Path(shards[0].input_file).name == "file1.tsv" + assert Path(shards[1].input_file).name == "file2.tsv" + + +@contextlib.contextmanager +def assert_warns(caplog, *, match: str) -> Iterator[None]: + caplog.clear() + sh.warn_once.cache_clear() + + with caplog.at_level(logging.WARN): + yield + assert len(caplog.messages) == 1 + assert re.match(match, caplog.messages[0]) + caplog.clear() + + +def test_text_sharding_config(tmp_path: Path): + df1 = get_random_table(10**3, seed=111).to_pandas() + df2 = get_random_table(10**2, seed=222).to_pandas() + + df1.to_csv(tmp_path / "file1.csv", index=False) + df2.to_csv(tmp_path / "file2.csv", index=False) + + tsc = TextShardingConfig( + input_file=str(tmp_path / "file*.csv"), + batch_size=11, + columns=["score", "cat"], + batch_format=BatchFormat.PANDAS, + filters_expr="(ds.field('score') < 0) & (ds.field('cat') >= 5)", + sep=",", + ) + shards: List[pd.DataFrame] = tsc.make_shards() + assert len(shards) == 2 + full_batches: pd.DataFrame = 
concat_batches( + [bb for shard in shards for bb in shard.to_batches(batch_size=None)] + ) + assert len(full_batches) == 280 + expected_df: pd.DataFrame = concat_batches([df1, df2]) + expected_df = expected_df[(expected_df["score"] < 0) & (expected_df["cat"] >= 5)] + assert list(full_batches.columns) == ["cat", "name", "score"] + assert permutationally_equal_dataframes(full_batches, expected_df) + + tsc.path_column = "my_custom_path_name" + another_full_batches: pd.DataFrame = concat_batches( + [ + bb + for shard in tsc.make_shards() + for bb in shard.to_batches(batch_size=tsc.batch_size) + ] + ) + + assert list(another_full_batches.columns) == [ + "cat", + "name", + "score", + "my_custom_path_name", + ] + assert list(another_full_batches["my_custom_path_name"].unique()) == [ + str(tmp_path / "file1.csv"), + str(tmp_path / "file2.csv"), + ] + + another_full_batches = another_full_batches[["cat", "name", "score"]] + assert permutationally_equal_dataframes(full_batches, another_full_batches) + + assert list(map(len, shards[0].to_batches(batch_size=tsc.batch_size))) == [ + 11 + ] * 22 + [10] + assert list(map(len, shards[1].to_batches(batch_size=tsc.batch_size))) == [ + 11 + ] * 2 + [6] + + single_df_few_columns = next( + shards[0].to_batches( + batch_size=tsc.batch_size, + columns=tsc.columns, + batch_format=tsc.batch_format, + ) + ) + assert list(single_df_few_columns.columns) == tsc.columns + single_df_all_columns = next( + shards[0].to_batches(batch_size=tsc.batch_size, batch_format=tsc.batch_format) + ) + assert permutationally_equal_dataframes( + single_df_few_columns, single_df_all_columns[tsc.columns] + ) + + +def test_output_text_config(tmp_path: Path): + toc = TextOutputConfig(str(tmp_path), compression="tar", sep="\t") + table1 = get_random_table(10**3, seed=211) + + paths_ = toc.write_batch(table1, iteration_index=(555,)) + assert len(paths_) == 1 + assert str(paths_[0]).endswith("555.tsv.tar") + reload_df = pd.read_csv(paths_[0], sep=toc.sep) + assert permutationally_equal_dataframes(reload_df, table1.to_pandas()) + + table2 = get_random_table(10**2, seed=231) + + paths_ = toc.write_batch(table2, iteration_index=(777,)) + assert len(paths_) == 1 + assert str(paths_[0]).endswith("777.tsv.tar") + reload_df = pd.read_csv(paths_[0], sep=toc.sep) + assert permutationally_equal_dataframes(reload_df, table2.to_pandas()) + + +def test_list_header_in_text_sharding_config(tmp_path: Path): + df1 = get_random_table(10**3, seed=111).to_pandas() + df2 = get_random_table(10**2, seed=222).to_pandas() + + df1.to_csv(tmp_path / "file1.csv", index=False) + df2.to_csv(tmp_path / "file2.csv", index=False) + + tsc = TextShardingConfig( + input_file=str(tmp_path / "file*.csv"), + batch_size=11, + header=["CAT", "NAME", "SCORE"], # upcase for tests + columns=["SCORE", "CAT"], + filters_expr="(ds.field('SCORE') < 0) & (ds.field('CAT') >= 5)", + batch_format=BatchFormat.PANDAS, + sep=",", + ) + shards: List[pd.DataFrame] = tsc.make_shards() + assert len(shards) == 2 + full_batches: pd.DataFrame = concat_batches( + [bb for shard in shards for bb in shard.to_batches(batch_size=None)] + ) + assert len(full_batches) == 280 + expected_df: pd.DataFrame = concat_batches([df1, df2]) + expected_df = expected_df[(expected_df["score"] < 0) & (expected_df["cat"] >= 5)] + assert list(full_batches.columns) == ["CAT", "NAME", "SCORE"] + assert permutationally_equal_dataframes( + full_batches, expected_df.rename(columns=str.upper) + )