From 42892d832b91bd83b23fbbf87f01a78ce5760ca8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 2 Jan 2025 12:12:07 -0500 Subject: [PATCH 01/18] TPS-free 2D bucket estimation and filtering MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../common/data/lhotse/dataloader.py | 16 ++- .../common/data/lhotse/sampling.py | 110 +++++++++----- .../estimate_duration_bins_2d.py | 62 ++++---- .../common/test_2d_bucketing_constraint.py | 136 +++++++++++++++--- 4 files changed, 238 insertions(+), 86 deletions(-) diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index bad866e6dac9..c7a772d56e82 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -33,7 +33,7 @@ make_worker_init_fn, ) from lhotse.dataset.dataloading import resolve_seed -from lhotse.dataset.sampling.base import CutSampler, TimeConstraint +from lhotse.dataset.sampling.base import CutSampler, SamplingConstraint, TimeConstraint from lhotse.lazy import LazyFlattener from lhotse.utils import fastcopy, fix_random_seed from omegaconf import DictConfig, OmegaConf @@ -44,6 +44,7 @@ read_cutset_from_config, ) from nemo.collections.common.data.lhotse.sampling import ( + BucketingFilter, DurationFilter, FixedBucketBatchSizeConstraint2D, MultimodalFixedBucketBatchSizeConstraint2D, @@ -91,6 +92,7 @@ class LhotseDataLoadingConfig: bucket_duration_bins: Any = None # list[float] | list[list[float]] | None = None bucket_buffer_size: int = 10000 concurrent_bucketing: bool = True # fetches data in a background thread + bucketing_2d_strict_mode: bool = True # reduces padding by discarding significant outliers # d. Other Lhotse sampling options. shuffle_buffer_size: int | None = 10000 drop_last: bool = False @@ -530,7 +532,7 @@ def get_lhotse_sampler_from_config(config, global_rank, world_size, tokenizer=No # Select the strategy customizing Lhotse sampler behaviour. # Provides support for dynamic batch sizes, multimodal dataloading, 2D bucketing, etc. bucket_duration_bins = determine_bucket_duration_bins(config) - constraint = determine_sampling_constraint(bucket_duration_bins, config) + cuts, constraint = determine_sampling_constraint(cuts, bucket_duration_bins, config) # 3. The sampler. if config.use_bucketing: @@ -608,13 +610,15 @@ def get_lhotse_sampler_from_config(config, global_rank, world_size, tokenizer=No return sampler, use_iterable_dataset -def determine_sampling_constraint(bucket_duration_bins, config): +def determine_sampling_constraint(cuts: CutSet, bucket_duration_bins, config) -> tuple[CutSet, SamplingConstraint]: """ Select an appropriate sampling strategy (constraint) for Lhotse samplers based on the configuration. Sampling constraint affects the batch size (static/dynamic) and bucketing behaviour (1D/2D). It is the appropriate customization point to introduce support of other modalities, as it defines a method for example sequence length measurement (audio duration, text tokens, etc.). + Some constraints apply extra filter on ``cuts`` which is why we accept and return the ``CutSet``. + Lhotse's default is :class:`TimeConstraint` for regular audio data, other available options are multimodal constraints (joint text + audio) and their 2D bucketing extensions. """ @@ -627,7 +631,9 @@ def determine_sampling_constraint(bucket_duration_bins, config): max_seq_len_buckets=bucket_duration_bins, batch_sizes=config.bucket_batch_size, token_equivalent_duration=config.token_equivalent_duration, + strict_2d=config.bucketing_2d_strict_mode, ) + cuts = cuts.filter(BucketingFilter(constraint)) else: constraint = MultimodalSamplingConstraint( token_equivalent_duration=config.token_equivalent_duration, @@ -643,14 +649,16 @@ def determine_sampling_constraint(bucket_duration_bins, config): constraint = FixedBucketBatchSizeConstraint2D( max_seq_len_buckets=bucket_duration_bins, batch_sizes=config.bucket_batch_size, + strict_2d=config.bucketing_2d_strict_mode, ) + cuts = cuts.filter(BucketingFilter(constraint)) else: constraint = TimeConstraint( max_cuts=config.batch_size, max_duration=config.batch_duration, quadratic_duration=config.quadratic_duration, ) - return constraint + return cuts, constraint def determine_bucket_duration_bins(config): diff --git a/nemo/collections/common/data/lhotse/sampling.py b/nemo/collections/common/data/lhotse/sampling.py index d645e3816300..c2d3dcd8be37 100644 --- a/nemo/collections/common/data/lhotse/sampling.py +++ b/nemo/collections/common/data/lhotse/sampling.py @@ -15,9 +15,11 @@ import bisect import logging import math +from bisect import bisect_left, bisect_right from dataclasses import dataclass from typing import Any, Sequence +import numpy as np from lhotse.cut import Cut from lhotse.dataset import SamplingConstraint, TokenConstraint from lhotse.dataset.sampling.dynamic_bucketing import FixedBucketBatchSizeConstraint @@ -110,11 +112,20 @@ class FixedBucketBatchSizeConstraint2D(FixedBucketBatchSizeConstraint): """ Sampling strategy that customizes Lhotse samplers to support 2D bucket selection (it also supports 1D). It is intended only for audio examples (i.e., Lhotse Cut objects). + + When ``strict_2d`` is set, we only consider sub-buckets for a single bucket that is the best match. + When set to ``False``, we'll promote an example to buckets with larger 1st dim if they can accommodate the 2nd dim. """ + strict_2d: bool = True + + def __post_init__(self): + if isinstance(self.max_seq_len_buckets[0], Sequence): + self.max_seq_len_buckets = np.asarray(self.max_seq_len_buckets) + @property def bucketing_2d_enabled(self) -> bool: - return isinstance(self.max_seq_len_buckets[0], Sequence) and len(self.max_seq_len_buckets[0]) == 2 + return isinstance(self.max_seq_len_buckets, np.ndarray) def measure_length(self, example: Cut) -> tuple[float, float] | float: if self.bucketing_2d_enabled: @@ -123,41 +134,54 @@ def measure_length(self, example: Cut) -> tuple[float, float] | float: return example.duration def select_bucket(self, buckets: Any, example: Any = None, example_len: Any = None) -> int: - if not self.bucketing_2d_enabled: - return super().select_bucket(buckets=buckets, example=example, example_len=example_len) if example_len is None: example_len = self.measure_length(example) - bucket_idx = bisect.bisect_left(buckets, example_len) - # For 2D bucketing we have to refine the initially found bucket_idx, as bisect - # looks primarily at the first index of a tuple (i.e. duration). - # For example, with buckets [(1, 1), (1, 2), (2, 2), (2, 4)] and example (1.5, 3) - # bisect would allocate it to bucket_idx=2 instead of bucket_idx=3. - # To refine, we'll try to push the example to as many buckets to the right as possible, - # as long as they have the same dim0 length (e.g. audio duration) and the example's dim1 - # is smaller than the bin's dim1 (e.g., output token sequence length). - bin_dim0, bin_dim1 = self.max_seq_len_buckets[bucket_idx] - num_buckets = len(self.max_seq_len_buckets) - while ( - (next_idx := bucket_idx + 1) < num_buckets # There is a next bucket - and (bin := self.max_seq_len_buckets[next_idx])[0] == bin_dim0 # The next bucket has the same 1st dim. - # The example's 2nd dim is between that of the current and the next bucket; or, - # the next bucket's 2nd dim is still smaller than example. - and (bin_dim1 < example_len[1] <= bin[1] or bin[1] < example_len[1]) - ): - bucket_idx = next_idx - bin_dim0, bin_dim1 = self.max_seq_len_buckets[bucket_idx] - - if example_len[0] > bin_dim0 or example_len[1] > bin_dim1: - logging.warning( - f"Data sample exceeds 2D bucket specification: lengths={example_len} bucket=({bin_dim0}, {bin_dim1}) " - f"(there is no larger bucket that would fit this example). " - f"We will keep it but expect OutOfMemoryError to happen during the training. " - f"You can fix this by stricter filtering with max_duration, max_tokens, max_tps, max_tpt; " - f"or re-estimating your bucket bins to match the actual data length distribution. " - f"Details: {example=}" - ) - - return bucket_idx + return find_smallest_bucket(self.max_seq_len_buckets, example_len, strict=self.strict_2d) + + +def find_smallest_bucket( + buckets: np.ndarray, example_lens: float | Sequence[float], strict: bool = True +) -> int | None: + """ + Find the smallest bucket that fits a given example. + Each bucket and ``example_lens`` are floats (1-D bucketing) + or tuples of (dim0, dim1, dim2, ...) (N-D bucketing, typically 2-D). + Assumes the buckets have been sorted ascendingly. + Returns a tuple of (smallest_bin, bin_idx), or (None, None) if no bucket fits the example. + """ + # 1D bucketing - binary search. + if isinstance(example_lens, float): # 1-D + idx = bisect_left(buckets, example_lens) + if idx == len(buckets): + return None + return idx + + # 2D bucketing 'strict' mode: only consider sub-buckets for the specific bucket that matches this example. + # E.g. for buckets = [(10, 5), (10, 10), (20, 12), (20, 18)] + # and example_lens = (8, 11) + # we will return None because we only consider the first two buckets based on dim0 (=8). + if strict: + dim0_begin = bisect_left(buckets[:, 0], example_lens[0]) + if dim0_begin == buckets.shape[0]: + return None + dim0_end = dim0_begin + while dim0_end < buckets.shape[0] and buckets[dim0_end, 0] == buckets[dim0_begin, 0]: + dim0_end += 1 + dim1_begin = bisect_left(buckets[dim0_begin:dim0_end, 1], example_lens[1]) + if dim1_begin == dim0_end - dim0_begin: + return None + return dim0_begin + dim1_begin + + # 2D bucketing 'lenient' mode - linear search (as 2nd dim may not be growing monotonically). + # E.g. for buckets = [(10, 5), (10, 10), (20, 12), (20, 18)] + # and example_lens = (8, 11) + # we will return bucket_idx=2 because (20, 12) fits (8, 11) at the cost of more padding. + does_fit = np.all(np.asarray(example_lens) <= buckets, axis=1) + min_fit_idx = np.argmax(does_fit) + if min_fit_idx or does_fit[min_fit_idx]: + return min_fit_idx.item() + else: + return None @dataclass @@ -301,6 +325,24 @@ def __call__(self, example) -> bool: return self.tpt_min <= tpt <= self.tpt_max +class BucketingFilter: + """ + Filters out examples that did not fit into any of the buckets. + Intended mainly for 2D bucketing. This filter is only active when + the constraint passed to it is of type ``FixedBucketBatchSizeConstraint2D``, + and is otherwise disabled. + """ + + def __init__(self, sampling_constraint: SamplingConstraint) -> None: + self.constraint = sampling_constraint + self.enabled = isinstance(self.constraint, FixedBucketBatchSizeConstraint2D) + + def __call__(self, example) -> bool: + if not self.enabled: + return True + return self.constraint.select_bucket(self.constraint.max_seq_len_buckets, example) is not None + + def _measure_tokens(cut: Cut) -> int: if hasattr(cut, "input_ids"): return len(cut.input_ids) # tokenized with prompt formatter diff --git a/scripts/speech_recognition/estimate_duration_bins_2d.py b/scripts/speech_recognition/estimate_duration_bins_2d.py index 0f4a021e09cc..7f14ef836058 100644 --- a/scripts/speech_recognition/estimate_duration_bins_2d.py +++ b/scripts/speech_recognition/estimate_duration_bins_2d.py @@ -15,6 +15,7 @@ import argparse import ast import math +import warnings from functools import partial from itertools import islice from pathlib import Path @@ -26,13 +27,10 @@ from omegaconf import OmegaConf from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper +from nemo.collections.common.data import apply_prompt_format_fn from nemo.collections.common.data.lhotse.cutset import read_cutset_from_config from nemo.collections.common.data.lhotse.dataloader import LhotseDataLoadingConfig, tokenize -from nemo.collections.common.data.lhotse.sampling import ( - DurationFilter, - FixedBucketBatchSizeConstraint2D, - TokenPerSecondFilter, -) +from nemo.collections.common.data.lhotse.sampling import DurationFilter, FixedBucketBatchSizeConstraint2D from nemo.collections.common.prompts.formatter import PromptFormatter from nemo.collections.common.tokenizers import AggregateTokenizer, SentencePieceTokenizer @@ -107,11 +105,7 @@ def parse_args(): help="If specified, we'll filter out utterances longer than this.", ) parser.add_argument( - "--max_tps", - type=float, - default=float("inf"), - help="If specified, we'll filter out utterances with more tokens/second than this. " - "On regular utterances and BPE tokenizers with 1024 tokens 10-12tps is generally a reasonable limit.", + "--max_tps", type=float, default=None, help="Deprecated. TPS is automatically determined per bucket." ) parser.add_argument( "-q", "--quiet", type=bool, default=False, help="When specified, only print the estimated duration bins." @@ -174,14 +168,6 @@ def estimate_duration_buckets( if math.isinf(max_duration): max_duration = sizes[-1] - tps = num_tokens / sizes - if not quiet: - print("Token per second distribution:") - print(pd.Series(tps).describe(percentiles=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])) - if math.isinf(max_tps): - max_tps = tps.max() - del tps - bins = [] bin_indexes = [0] tot = 0.0 @@ -193,8 +179,20 @@ def _estimate_token_buckets(max_bucket_duration): # Note that this estimation is biased towards more padding if you have # a lot of zero-token examples (e.g. non-speech). nonlocal bins + + # Start by discarding outlier examples as defined by token-per-second (TPS) attribute. + # We empirically determined high TPS examples to cause severe OOMs limiting batch sizes. + # We cap the TPS for each top-level bucket at 4 standard deviations of TPS. + # Examples exceeding that TPS value will be discarded during sampling at training time. num_tokens_bucket = num_tokens[bin_indexes[-1] : binidx] + non_outlier_indexes = find_non_outliers_z_score(num_tokens_bucket / sizes[bin_indexes[-1] : binidx]) + num_tokens_bucket = num_tokens[non_outlier_indexes] num_tokens_bucket.sort() + if not quiet: + print( + f"[bucket <={max_bucket_duration}s] [{num_tokens_bucket.min()} - {num_tokens_bucket.max()}] Discarded {binidx - bin_indexes[-1] - len(num_tokens_bucket)} TPS outliers." + ) + tokens_per_subbucket = num_tokens_bucket.sum() / num_subbuckets tot_toks = 0 # Iterate over token counts, and whenever we hit tokens_per_subbucket, create a new 2D bucket bin. @@ -204,7 +202,7 @@ def _estimate_token_buckets(max_bucket_duration): bins.append((max_bucket_duration, num_toks)) tot_toks = 0 tot_toks += num_toks - bins.append((size, math.ceil(size * max_tps))) + bins.append((size, num_toks)) # Iterate over data, and whenever we hit size_per_bucket, create a new bucket bin. for binidx, size in enumerate(sizes): @@ -220,6 +218,11 @@ def _estimate_token_buckets(max_bucket_duration): return bins +def find_non_outliers_z_score(data, threshold=4): + z_scores = np.abs((data - np.mean(data)) / np.std(data)) + return np.where(z_scores <= threshold) + + def load_tokenizer(paths: list[str], langs: list[str] = None) -> TokenizerWrapper: if len(paths) == 1: tok = SentencePieceTokenizer(paths[0]) @@ -233,15 +236,8 @@ def load_tokenizer(paths: list[str], langs: list[str] = None) -> TokenizerWrappe def apply_tokenizer(cut, tokenizer=None, prompt: PromptFormatter = None): if prompt is not None: - turns = prompt.get_default_dialog_slots() - last_turn = {"role": prompt.OUTPUT_ROLE, "slots": prompt.get_slots(prompt.OUTPUT_ROLE)} - assert len(last_turn["slots"]) == 1 # TODO: not sure how to handle multi-slot for system output here - for key in last_turn["slots"]: - last_turn["slots"][key] = cut.supervisions[0].text - last_turn["slots"][prompt.PROMPT_LANGUAGE_SLOT] = cut.supervisions[0].language - turns.append(last_turn) - ans = prompt.encode_dialog(turns) - cut.supervisions[0].tokens = ans["input_ids"] + encoded = apply_prompt_format_fn(cut, prompt) + cut.supervisions[0].tokens = encoded["input_ids"] elif tokenizer is not None: cut = tokenize(cut, tokenizer) @@ -274,6 +270,12 @@ def main(): if not args.quiet: pd.set_option('display.float_format', lambda x: '%.2f' % x) + if args.max_tps is not None: + warnings.warn( + "The option --max_tps has been deprecated in favor of " + "automatic TPS determination that's variable across buckets." + ) + tokenizer = None prompt = None if args.tokenizer is not None: @@ -302,8 +304,6 @@ def main(): duration_filter = RejectionsCounter(DurationFilter(args.min_duration, args.max_duration), "Duration filtering") cuts = cuts.filter(duration_filter) cuts = cuts.map(partial(apply_tokenizer, tokenizer=tokenizer, prompt=prompt)) - tps_filter = RejectionsCounter(TokenPerSecondFilter(-1, args.max_tps), "Token per second filtering") - cuts = cuts.filter(tps_filter) if (N := args.num_examples) > 0: cuts = islice(cuts, N) @@ -311,7 +311,6 @@ def main(): cuts, num_buckets=args.buckets, num_subbuckets=args.sub_buckets, - max_tps=args.max_tps, max_duration=args.max_duration, quiet=args.quiet, ) @@ -320,7 +319,6 @@ def main(): print(duration_bins) return duration_filter.print_report() - tps_filter.print_report() print("Use the following options in your config:") print(f"\tnum_buckets={args.buckets}") print(f"\tbucket_duration_bins={duration_bins}") diff --git a/tests/collections/common/test_2d_bucketing_constraint.py b/tests/collections/common/test_2d_bucketing_constraint.py index 36cb9825ac5b..ff58ea156cdc 100644 --- a/tests/collections/common/test_2d_bucketing_constraint.py +++ b/tests/collections/common/test_2d_bucketing_constraint.py @@ -16,27 +16,28 @@ import pytest from lhotse import CutSet, Seconds, SupervisionSegment from lhotse.dataset import DynamicBucketingSampler -from lhotse.testing.dummies import DummyManifest, dummy_cut -from nemo.collections.common.data.lhotse.sampling import FixedBucketBatchSizeConstraint2D +from lhotse.testing.dummies import dummy_cut +from nemo.collections.common.data.lhotse.dataloader import BucketingFilter, FixedBucketBatchSizeConstraint2D + + +def make_cut(id_: int = 0, duration: Seconds = 1.0, num_tokens: int = 10): + supervision = SupervisionSegment(f"blah-{id_}", f"blah-{id_}", 0.0, duration, text="a" * num_tokens) + supervision.tokens = np.zeros((num_tokens,), dtype=np.int32) + return dummy_cut(id_, duration=duration, supervisions=[supervision]) @pytest.fixture def cuts(): - def _cut(id_: int, duration: Seconds, num_tokens: int): - supervision = SupervisionSegment(f"blah-{id_}", f"blah-{id_}", 0.0, duration, text="a" * num_tokens) - supervision.tokens = np.zeros((num_tokens,), dtype=np.int32) - return dummy_cut(id_, duration=duration, supervisions=[supervision]) - return CutSet( - [_cut(i, duration=2.0, num_tokens=4) for i in range(20)] - + [_cut(i, duration=2.0, num_tokens=8) for i in range(20)] - + [_cut(i, duration=2.0, num_tokens=12) for i in range(20)] - + [_cut(i, duration=8.0, num_tokens=8) for i in range(20)] - + [_cut(i, duration=8.0, num_tokens=12) for i in range(20)] - + [_cut(i, duration=8.0, num_tokens=16) for i in range(20)] - + [_cut(i, duration=14.0, num_tokens=12) for i in range(20)] - + [_cut(i, duration=14.0, num_tokens=16) for i in range(20)] - + [_cut(i, duration=14.0, num_tokens=20) for i in range(20)] + [make_cut(i, duration=2.0, num_tokens=4) for i in range(20)] + + [make_cut(i, duration=2.0, num_tokens=8) for i in range(20)] + + [make_cut(i, duration=2.0, num_tokens=12) for i in range(20)] + + [make_cut(i, duration=8.0, num_tokens=8) for i in range(20)] + + [make_cut(i, duration=8.0, num_tokens=12) for i in range(20)] + + [make_cut(i, duration=8.0, num_tokens=16) for i in range(20)] + + [make_cut(i, duration=14.0, num_tokens=12) for i in range(20)] + + [make_cut(i, duration=14.0, num_tokens=16) for i in range(20)] + + [make_cut(i, duration=14.0, num_tokens=20) for i in range(20)] ) @@ -63,6 +64,7 @@ def test_2d_bucketing_expected_bucket_allocation(cuts): constraint=FixedBucketBatchSizeConstraint2D( max_seq_len_buckets=duration_bins, batch_sizes=batch_sizes, + strict_2d=False, ), buffer_size=1000, seed=0, @@ -101,3 +103,105 @@ def test_2d_bucketing_expected_bucket_allocation(cuts): except ValueError as e: if "max() arg is an empty sequence" not in str(e): raise + + +@pytest.mark.parametrize( + ["duration", "num_tokens", "should_keep", "bucket_idx"], + [ + # Buckets for duration range [0.0-5.0]: + # * Sweep num_tokens + (2.0, 0, True, 0), + (2.0, 5, True, 0), + (2.0, 10, True, 0), + (2.0, 11, True, 1), + (2.0, 20, True, 1), + (2.0, 21, True, 3), + (2.0, 30, True, 3), + (2.0, 31, False, None), + # * Check the upper bound duration 5.0 + (5.0, 0, True, 0), + (5.0, 5, True, 0), + (5.0, 10, True, 0), + (5.0, 11, True, 1), + (5.0, 20, True, 1), + (5.0, 21, True, 3), + (5.0, 30, True, 3), + (5.0, 31, False, None), + # Buckets for duration range [5.0, 10.0] + # * Sweep num_tokens + (8.0, 0, True, 2), + (8.0, 15, True, 2), + (8.0, 16, True, 3), + (8.0, 30, True, 3), + (8.0, 31, False, None), + # * Check the upper bound duration 10.0 + (10.0, 0, True, 2), + (10.0, 15, True, 2), + (10.0, 16, True, 3), + (10.0, 30, True, 3), + (10.0, 31, False, None), + # Durations above max duration + (20.0, 0, False, None), + (20.0, 1000, False, None), + ], +) +def test_2d_bucketing_filter_lenient(duration, num_tokens, should_keep, bucket_idx): + buckets = [(5.0, 10), (5.0, 20), (10.0, 15), (10.0, 30)] + batch_sizes = [4, 3, 2, 1] + constraint = FixedBucketBatchSizeConstraint2D(buckets, batch_sizes, strict_2d=False) + filter_2d = BucketingFilter(constraint) + + cut = make_cut(duration=duration, num_tokens=num_tokens) + assert filter_2d(cut) == should_keep + assert constraint.select_bucket(constraint.max_seq_len_buckets, cut) == bucket_idx + + +@pytest.mark.parametrize( + ["duration", "num_tokens", "should_keep", "bucket_idx"], + [ + # Buckets for duration range [0.0-5.0]: + # * Sweep num_tokens + (2.0, 0, True, 0), + (2.0, 5, True, 0), + (2.0, 10, True, 0), + (2.0, 11, True, 1), + (2.0, 20, True, 1), + (2.0, 21, False, None), # <-- strict + (2.0, 30, False, None), # <-- strict + (2.0, 31, False, None), + # * Check the upper bound duration 5.0 + (5.0, 0, True, 0), + (5.0, 5, True, 0), + (5.0, 10, True, 0), + (5.0, 11, True, 1), + (5.0, 20, True, 1), + (5.0, 21, False, None), # <-- strict + (5.0, 30, False, None), # <-- strict + (5.0, 31, False, None), + # Buckets for duration range [5.0, 10.0] + # * Sweep num_tokens + (8.0, 0, True, 2), + (8.0, 15, True, 2), + (8.0, 16, True, 3), + (8.0, 30, True, 3), + (8.0, 31, False, None), + # * Check the upper bound duration 10.0 + (10.0, 0, True, 2), + (10.0, 15, True, 2), + (10.0, 16, True, 3), + (10.0, 30, True, 3), + (10.0, 31, False, None), + # Durations above max duration + (20.0, 0, False, None), + (20.0, 1000, False, None), + ], +) +def test_2d_bucketing_filter_strict(duration, num_tokens, should_keep, bucket_idx): + buckets = [(5.0, 10), (5.0, 20), (10.0, 15), (10.0, 30)] + batch_sizes = [4, 3, 2, 1] + constraint = FixedBucketBatchSizeConstraint2D(buckets, batch_sizes, strict_2d=True) + filter_2d = BucketingFilter(constraint) + + cut = make_cut(duration=duration, num_tokens=num_tokens) + assert filter_2d(cut) == should_keep + assert constraint.select_bucket(constraint.max_seq_len_buckets, cut) == bucket_idx From 8291b097f21d184d838b67cebe0b2b9d12f55bb2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 2 Jan 2025 12:22:02 -0500 Subject: [PATCH 02/18] fixes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../estimate_duration_bins_2d.py | 24 ++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/scripts/speech_recognition/estimate_duration_bins_2d.py b/scripts/speech_recognition/estimate_duration_bins_2d.py index 7f14ef836058..8306d933579a 100644 --- a/scripts/speech_recognition/estimate_duration_bins_2d.py +++ b/scripts/speech_recognition/estimate_duration_bins_2d.py @@ -26,13 +26,17 @@ from lhotse.cut import Cut from omegaconf import OmegaConf -from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper from nemo.collections.common.data import apply_prompt_format_fn from nemo.collections.common.data.lhotse.cutset import read_cutset_from_config from nemo.collections.common.data.lhotse.dataloader import LhotseDataLoadingConfig, tokenize from nemo.collections.common.data.lhotse.sampling import DurationFilter, FixedBucketBatchSizeConstraint2D from nemo.collections.common.prompts.formatter import PromptFormatter -from nemo.collections.common.tokenizers import AggregateTokenizer, SentencePieceTokenizer +from nemo.collections.common.tokenizers import ( + AggregateTokenizer, + CanaryTokenizer, + SentencePieceTokenizer, + TokenizerSpec, +) def parse_args(): @@ -164,7 +168,7 @@ def estimate_duration_buckets( if not quiet: print("Duration distribution:") - print(pd.Series(sizes).describe(percentiles=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])) + print(pd.Series(sizes).describe(percentiles=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99, 0.995, 0.999])) if math.isinf(max_duration): max_duration = sizes[-1] @@ -223,15 +227,19 @@ def find_non_outliers_z_score(data, threshold=4): return np.where(z_scores <= threshold) -def load_tokenizer(paths: list[str], langs: list[str] = None) -> TokenizerWrapper: +def load_tokenizer(paths: list[str], langs: list[str] = None, is_canary: bool = True) -> TokenizerSpec: if len(paths) == 1: tok = SentencePieceTokenizer(paths[0]) else: assert langs is not None and len(paths) == len( langs ), f"Cannot create AggregateTokenizer; each tokenizer must have assigned a language via --langs option (we got --tokenizers={paths} and --langs={langs})" - tok = AggregateTokenizer({lang: SentencePieceTokenizer(p) for lang, p in zip(langs, paths)}) - return TokenizerWrapper(tok) + if is_canary: + tokcls = CanaryTokenizer + else: + tokcls = AggregateTokenizer + tok = tokcls({lang: SentencePieceTokenizer(p) for lang, p in zip(langs, paths)}) + return tok def apply_tokenizer(cut, tokenizer=None, prompt: PromptFormatter = None): @@ -279,12 +287,12 @@ def main(): tokenizer = None prompt = None if args.tokenizer is not None: - tokenizer = load_tokenizer(args.tokenizer, args.langs) + tokenizer = load_tokenizer(args.tokenizer, args.langs, 'canary' in args.prompt_format) if args.prompt_format is not None: prompt_defaults = None if args.prompt is not None: prompt_defaults = ast.literal_eval(args.prompt) - prompt = PromptFormatter.resolve(args.prompt_format)(tokenizer._tokenizer, defaults=prompt_defaults) + prompt = PromptFormatter.resolve(args.prompt_format)(tokenizer, defaults=prompt_defaults) if '=' in args.input: inp_arg = args.input From 48230c54ba6ac16fde31c662180366e6f52ef19a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 3 Jan 2025 09:00:22 -0500 Subject: [PATCH 03/18] extra unit test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../common/test_2d_bucketing_constraint.py | 62 +++++++++++++++++-- 1 file changed, 58 insertions(+), 4 deletions(-) diff --git a/tests/collections/common/test_2d_bucketing_constraint.py b/tests/collections/common/test_2d_bucketing_constraint.py index ff58ea156cdc..21bda0ce8c3f 100644 --- a/tests/collections/common/test_2d_bucketing_constraint.py +++ b/tests/collections/common/test_2d_bucketing_constraint.py @@ -14,15 +14,22 @@ import numpy as np import pytest +import torch.utils.data from lhotse import CutSet, Seconds, SupervisionSegment from lhotse.dataset import DynamicBucketingSampler from lhotse.testing.dummies import dummy_cut -from nemo.collections.common.data.lhotse.dataloader import BucketingFilter, FixedBucketBatchSizeConstraint2D +from lhotse.testing.random import deterministic_rng + +from nemo.collections.common.data.lhotse.dataloader import ( + BucketingFilter, + FixedBucketBatchSizeConstraint2D, + get_lhotse_dataloader_from_config, +) def make_cut(id_: int = 0, duration: Seconds = 1.0, num_tokens: int = 10): supervision = SupervisionSegment(f"blah-{id_}", f"blah-{id_}", 0.0, duration, text="a" * num_tokens) - supervision.tokens = np.zeros((num_tokens,), dtype=np.int32) + supervision.tokens = np.zeros((num_tokens,), dtype=np.int32).tolist() return dummy_cut(id_, duration=duration, supervisions=[supervision]) @@ -81,7 +88,7 @@ def test_2d_bucketing_expected_bucket_allocation(cuts): for cut in batch: # First, check that the sampled examples are indeed below the max duration/num_tokens for its bucket. assert cut.duration <= max_duration - assert cut.supervisions[0].tokens.shape[0] <= max_num_tokens + assert len(cut.supervisions[0].tokens) <= max_num_tokens # Then, find the previous compatible bucket for each of training example's dimensions, # and verify that it was not possible to assign the example to that smaller bucket. # We should skip this for bucket_idx==0 (no previous buckets available). @@ -99,7 +106,7 @@ def test_2d_bucketing_expected_bucket_allocation(cuts): prev_max_num_tokens = max( tok for dur, tok in duration_bins[:bin_index] if dur == max_duration and tok < max_num_tokens ) - assert cut.supervisions[0].tokens.shape[0] > prev_max_num_tokens + assert len(cut.supervisions[0].tokens) > prev_max_num_tokens except ValueError as e: if "max() arg is an empty sequence" not in str(e): raise @@ -205,3 +212,50 @@ def test_2d_bucketing_filter_strict(duration, num_tokens, should_keep, bucket_id cut = make_cut(duration=duration, num_tokens=num_tokens) assert filter_2d(cut) == should_keep assert constraint.select_bucket(constraint.max_seq_len_buckets, cut) == bucket_idx + + +class _Identity(torch.utils.data.Dataset): + def __getitem__(self, item): + return item + + +def test_2d_bucketing_strict_mode_flag_works(deterministic_rng, tmp_path): + cuts_path = tmp_path / "cuts.jsonl" + CutSet([make_cut(0, duration=1.0, num_tokens=10), make_cut(0, duration=1.0, num_tokens=100)]).to_file(cuts_path) + + # Strict mode enabled + dloader = get_lhotse_dataloader_from_config( + { + "cuts_path": cuts_path, + "use_bucketing": True, + "bucket_duration_bins": [(5.0, 10), (5.0, 20), (10.0, 150), (10.0, 300)], + "bucket_batch_size": [1, 1, 1, 1], + "bucketing_2d_strict_mode": True, + }, + global_rank=0, + world_size=1, + dataset=_Identity(), + ) + batches = [b for b in dloader] + assert len(batches) == 1 + assert len(batches[0]) == 1 + assert len(batches[0][0].supervisions[0].tokens) == 10 + + # Strict mode disabled + dloader = get_lhotse_dataloader_from_config( + { + "cuts_path": cuts_path, + "use_bucketing": True, + "bucket_duration_bins": [(5.0, 10), (5.0, 20), (10.0, 150), (10.0, 300)], + "bucket_batch_size": [1, 1, 1, 1], + "bucketing_2d_strict_mode": False, + }, + global_rank=0, + world_size=1, + dataset=_Identity(), + ) + batches = [b for b in dloader] + assert len(batches) == 2 + assert len(batches[0]) == 1 + assert len(batches[0][0].supervisions[0].tokens) == 100 + assert len(batches[1][0].supervisions[0].tokens) == 10 From c48b813b9139f409e3c55acdfc21aba9168a7d26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 6 Jan 2025 10:18:24 -0800 Subject: [PATCH 04/18] Fixes in 2D bin estimation script and OOMptimizer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../estimate_duration_bins_2d.py | 62 +++++++++++++------ scripts/speech_recognition/oomptimizer.py | 2 +- 2 files changed, 44 insertions(+), 20 deletions(-) diff --git a/scripts/speech_recognition/estimate_duration_bins_2d.py b/scripts/speech_recognition/estimate_duration_bins_2d.py index 8306d933579a..4ede4df3872d 100644 --- a/scripts/speech_recognition/estimate_duration_bins_2d.py +++ b/scripts/speech_recognition/estimate_duration_bins_2d.py @@ -111,6 +111,10 @@ def parse_args(): parser.add_argument( "--max_tps", type=float, default=None, help="Deprecated. TPS is automatically determined per bucket." ) + parser.add_argument( + "--token_outlier_threshold", type=float, default=4.0, help="The lower this is, the more outliers in transcript token count will be filtered out. " + "By default allow token counts at 4 sigma away from distribution mean, computed separately for every bucket." + ) parser.add_argument( "-q", "--quiet", type=bool, default=False, help="When specified, only print the estimated duration bins." ) @@ -132,12 +136,19 @@ def parse_args(): return parser.parse_args() +def sort_two_arrays(A, B): + joint = np.rec.fromarrays([A, B]) + joint.sort() + return joint.f0, joint.f1 + + def estimate_duration_buckets( cuts: Iterable[Cut], num_buckets: int, num_subbuckets: int, max_tps: float, max_duration: float, + token_outlier_threshold: float, quiet: bool, ) -> list[tuple[float, float]]: """ @@ -157,10 +168,7 @@ def estimate_duration_buckets( num_tokens.append(toks) sizes = np.array(sizes, dtype=np.float32) num_tokens = np.array(num_tokens, dtype=np.int32) - joint = np.rec.fromarrays([sizes, num_tokens]) - joint.sort() - sizes = joint.f0 - num_tokens = joint.f1 + sizes, num_tokens = sort_two_arrays(sizes, num_tokens) # We are building buckets with equal duration (empirically leads to more even bucket exhaustion over time). # We need to determine how much duration to allocate per bucket. @@ -173,6 +181,7 @@ def estimate_duration_buckets( max_duration = sizes[-1] bins = [] + tps_thresholds = [] bin_indexes = [0] tot = 0.0 @@ -188,42 +197,54 @@ def _estimate_token_buckets(max_bucket_duration): # We empirically determined high TPS examples to cause severe OOMs limiting batch sizes. # We cap the TPS for each top-level bucket at 4 standard deviations of TPS. # Examples exceeding that TPS value will be discarded during sampling at training time. - num_tokens_bucket = num_tokens[bin_indexes[-1] : binidx] - non_outlier_indexes = find_non_outliers_z_score(num_tokens_bucket / sizes[bin_indexes[-1] : binidx]) - num_tokens_bucket = num_tokens[non_outlier_indexes] - num_tokens_bucket.sort() + num_tokens_bucket_all = num_tokens[bin_indexes[-1] : binidx] + sizes_bucket_all = sizes[bin_indexes[-1] : binidx] + non_outlier_indexes = find_non_outliers_z_score(num_tokens_bucket_all / sizes_bucket_all, threshold=token_outlier_threshold) + num_tokens_bucket = num_tokens_bucket_all[non_outlier_indexes] + sizes_bucket = sizes_bucket_all[non_outlier_indexes] + max_tps_bucket = (num_tokens_bucket / sizes_bucket).max() + num_tokens_bucket, sizes_bucket = sort_two_arrays(num_tokens_bucket, sizes_bucket) if not quiet: + outlier_tps = np.delete(num_tokens_bucket_all / sizes_bucket_all, non_outlier_indexes) print( - f"[bucket <={max_bucket_duration}s] [{num_tokens_bucket.min()} - {num_tokens_bucket.max()}] Discarded {binidx - bin_indexes[-1] - len(num_tokens_bucket)} TPS outliers." + f"[bucket <= {max_bucket_duration:.2f}s] [{num_tokens_bucket.min()} - {num_tokens_bucket.max()}] [approx-max-tps: {max_tps_bucket:.2f}] Discarded {binidx - bin_indexes[-1] - len(num_tokens_bucket)} max token outliers", end=" " ) + if len(outlier_tps) > 0: + print(f"min-outlier: {outlier_tps.min():.2f}, max-outlier: {outlier_tps.max():.2f}).", end="") + print() tokens_per_subbucket = num_tokens_bucket.sum() / num_subbuckets tot_toks = 0 # Iterate over token counts, and whenever we hit tokens_per_subbucket, create a new 2D bucket bin. - for num_toks in num_tokens_bucket: + for num_toks, size in zip(num_tokens_bucket, sizes_bucket): # Threshold hit: we are creating a new (max_duration, max_num_tokens) bin. if tot_toks > tokens_per_subbucket: bins.append((max_bucket_duration, num_toks)) + tps_thresholds.append(max_tps_bucket) tot_toks = 0 tot_toks += num_toks - bins.append((size, num_toks)) + bins.append((max_bucket_duration, num_toks)) + tps_thresholds.append(max_tps_bucket) # Iterate over data, and whenever we hit size_per_bucket, create a new bucket bin. for binidx, size in enumerate(sizes): if tot > size_per_bucket: # Threshold hit: we are creating a new duration bin (multiplied by number of token bins). _estimate_token_buckets(max_bucket_duration=size) + bin_indexes.append(binidx) tot = 0.0 tot += size # Estimate an extra 2D bin set for global max duration. _estimate_token_buckets(max_bucket_duration=max_duration) - return bins + return bins, tps_thresholds def find_non_outliers_z_score(data, threshold=4): - z_scores = np.abs((data - np.mean(data)) / np.std(data)) + # Note: we don't apply abs() here because we only filter the upper end of the distribution. + # We don't mind low-token-counts for bucketing purposes. + z_scores = (data - np.mean(data)) / np.std(data) return np.where(z_scores <= threshold) @@ -315,21 +336,24 @@ def main(): if (N := args.num_examples) > 0: cuts = islice(cuts, N) - duration_bins = estimate_duration_buckets( + duration_bins, tps_thresholds = estimate_duration_buckets( cuts, num_buckets=args.buckets, num_subbuckets=args.sub_buckets, max_duration=args.max_duration, + max_tps=args.max_tps, + token_outlier_threshold=args.token_outlier_threshold, quiet=args.quiet, ) duration_bins = "[" + ','.join(f"[{b:.3f},{sb:d}]" for b, sb in duration_bins) + "]" - if args.quiet: - print(duration_bins) - return - duration_filter.print_report() - print("Use the following options in your config:") + tps_thresholds = "[" + ",".join(f"{t:.2f}" for t in tps_thresholds) + "]" + if not args.quiet: + duration_filter.print_report() + print("Use the following options in your config:") + print(f"\tuse_bucketing=1") print(f"\tnum_buckets={args.buckets}") print(f"\tbucket_duration_bins={duration_bins}") + print(f"\tmax_tps={tps_thresholds}") if __name__ == "__main__": diff --git a/scripts/speech_recognition/oomptimizer.py b/scripts/speech_recognition/oomptimizer.py index b44c2c46c629..8e95641fc7ef 100755 --- a/scripts/speech_recognition/oomptimizer.py +++ b/scripts/speech_recognition/oomptimizer.py @@ -408,7 +408,7 @@ def oomptimizer( ( "text" if any( - isinstance(item["type"].elements_type, LabelsType) and item["seq_length"] == direction + isinstance(item["type"], NeuralType) and isinstance(item["type"].elements_type, LabelsType) and item["seq_length"] == direction for item in schema["inputs"] if item["type"] != "dummy" ) From 83cd8df25afd7bc06445cd4b45929bb0b9b8c033 Mon Sep 17 00:00:00 2001 From: pzelasko Date: Mon, 6 Jan 2025 18:19:34 +0000 Subject: [PATCH 05/18] Apply isort and black reformatting Signed-off-by: pzelasko --- .../estimate_duration_bins_2d.py | 14 ++++++++++---- scripts/speech_recognition/oomptimizer.py | 4 +++- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/scripts/speech_recognition/estimate_duration_bins_2d.py b/scripts/speech_recognition/estimate_duration_bins_2d.py index 4ede4df3872d..1c57eb2c583d 100644 --- a/scripts/speech_recognition/estimate_duration_bins_2d.py +++ b/scripts/speech_recognition/estimate_duration_bins_2d.py @@ -112,8 +112,11 @@ def parse_args(): "--max_tps", type=float, default=None, help="Deprecated. TPS is automatically determined per bucket." ) parser.add_argument( - "--token_outlier_threshold", type=float, default=4.0, help="The lower this is, the more outliers in transcript token count will be filtered out. " - "By default allow token counts at 4 sigma away from distribution mean, computed separately for every bucket." + "--token_outlier_threshold", + type=float, + default=4.0, + help="The lower this is, the more outliers in transcript token count will be filtered out. " + "By default allow token counts at 4 sigma away from distribution mean, computed separately for every bucket.", ) parser.add_argument( "-q", "--quiet", type=bool, default=False, help="When specified, only print the estimated duration bins." @@ -199,7 +202,9 @@ def _estimate_token_buckets(max_bucket_duration): # Examples exceeding that TPS value will be discarded during sampling at training time. num_tokens_bucket_all = num_tokens[bin_indexes[-1] : binidx] sizes_bucket_all = sizes[bin_indexes[-1] : binidx] - non_outlier_indexes = find_non_outliers_z_score(num_tokens_bucket_all / sizes_bucket_all, threshold=token_outlier_threshold) + non_outlier_indexes = find_non_outliers_z_score( + num_tokens_bucket_all / sizes_bucket_all, threshold=token_outlier_threshold + ) num_tokens_bucket = num_tokens_bucket_all[non_outlier_indexes] sizes_bucket = sizes_bucket_all[non_outlier_indexes] max_tps_bucket = (num_tokens_bucket / sizes_bucket).max() @@ -207,7 +212,8 @@ def _estimate_token_buckets(max_bucket_duration): if not quiet: outlier_tps = np.delete(num_tokens_bucket_all / sizes_bucket_all, non_outlier_indexes) print( - f"[bucket <= {max_bucket_duration:.2f}s] [{num_tokens_bucket.min()} - {num_tokens_bucket.max()}] [approx-max-tps: {max_tps_bucket:.2f}] Discarded {binidx - bin_indexes[-1] - len(num_tokens_bucket)} max token outliers", end=" " + f"[bucket <= {max_bucket_duration:.2f}s] [{num_tokens_bucket.min()} - {num_tokens_bucket.max()}] [approx-max-tps: {max_tps_bucket:.2f}] Discarded {binidx - bin_indexes[-1] - len(num_tokens_bucket)} max token outliers", + end=" ", ) if len(outlier_tps) > 0: print(f"min-outlier: {outlier_tps.min():.2f}, max-outlier: {outlier_tps.max():.2f}).", end="") diff --git a/scripts/speech_recognition/oomptimizer.py b/scripts/speech_recognition/oomptimizer.py index 8e95641fc7ef..4b4deb092bea 100755 --- a/scripts/speech_recognition/oomptimizer.py +++ b/scripts/speech_recognition/oomptimizer.py @@ -408,7 +408,9 @@ def oomptimizer( ( "text" if any( - isinstance(item["type"], NeuralType) and isinstance(item["type"].elements_type, LabelsType) and item["seq_length"] == direction + isinstance(item["type"], NeuralType) + and isinstance(item["type"].elements_type, LabelsType) + and item["seq_length"] == direction for item in schema["inputs"] if item["type"] != "dummy" ) From ab2f4d66b8242010facf16b5f705d36ba28c3da7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 6 Jan 2025 13:39:54 -0500 Subject: [PATCH 06/18] Support list values for max_tps/max_tpt and bucketing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../common/data/lhotse/dataloader.py | 2 ++ .../common/data/lhotse/sampling.py | 32 +++++++++++++++++-- .../common/test_2d_bucketing_constraint.py | 20 ++++++++++++ 3 files changed, 51 insertions(+), 3 deletions(-) diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index c7a772d56e82..fab106981f9e 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -632,6 +632,7 @@ def determine_sampling_constraint(cuts: CutSet, bucket_duration_bins, config) -> batch_sizes=config.bucket_batch_size, token_equivalent_duration=config.token_equivalent_duration, strict_2d=config.bucketing_2d_strict_mode, + max_ratio=config.max_tpt if isinstance(config.max_tpt, Sequence) else None, ) cuts = cuts.filter(BucketingFilter(constraint)) else: @@ -650,6 +651,7 @@ def determine_sampling_constraint(cuts: CutSet, bucket_duration_bins, config) -> max_seq_len_buckets=bucket_duration_bins, batch_sizes=config.bucket_batch_size, strict_2d=config.bucketing_2d_strict_mode, + max_ratio=config.max_tps if isinstance(config.max_tps, Sequence) else None, ) cuts = cuts.filter(BucketingFilter(constraint)) else: diff --git a/nemo/collections/common/data/lhotse/sampling.py b/nemo/collections/common/data/lhotse/sampling.py index c2d3dcd8be37..9de896a862f1 100644 --- a/nemo/collections/common/data/lhotse/sampling.py +++ b/nemo/collections/common/data/lhotse/sampling.py @@ -115,13 +115,23 @@ class FixedBucketBatchSizeConstraint2D(FixedBucketBatchSizeConstraint): When ``strict_2d`` is set, we only consider sub-buckets for a single bucket that is the best match. When set to ``False``, we'll promote an example to buckets with larger 1st dim if they can accommodate the 2nd dim. + + When ``max_ratio`` is set, it discards the examples that exceed a specific output-to-input length ratio. + ``max_ratio`` must be a list with the same length as the number of buckets. + ``max_ratio`` is only applied when ``strict_2d`` is set to ``True``. """ strict_2d: bool = True + max_ratio: list[float] | None = None def __post_init__(self): if isinstance(self.max_seq_len_buckets[0], Sequence): self.max_seq_len_buckets = np.asarray(self.max_seq_len_buckets) + if self.max_ratio is not None: + assert isinstance(self.max_ratio, Sequence), f"self.max_ratio must be a list, but we got: {self.max_ratio}" + assert len(self.max_ratio) == len( + self.max_seq_len_buckets + ), f"{len(self.max_ratio)=} != {len(self.max_seq_len_buckets)=}" @property def bucketing_2d_enabled(self) -> bool: @@ -136,11 +146,16 @@ def measure_length(self, example: Cut) -> tuple[float, float] | float: def select_bucket(self, buckets: Any, example: Any = None, example_len: Any = None) -> int: if example_len is None: example_len = self.measure_length(example) - return find_smallest_bucket(self.max_seq_len_buckets, example_len, strict=self.strict_2d) + return find_smallest_bucket( + self.max_seq_len_buckets, example_len, strict=self.strict_2d, max_ratio=self.max_ratio + ) def find_smallest_bucket( - buckets: np.ndarray, example_lens: float | Sequence[float], strict: bool = True + buckets: np.ndarray, + example_lens: float | Sequence[float], + strict: bool = True, + max_ratio: Sequence[float] | None = None, ) -> int | None: """ Find the smallest bucket that fits a given example. @@ -161,16 +176,23 @@ def find_smallest_bucket( # and example_lens = (8, 11) # we will return None because we only consider the first two buckets based on dim0 (=8). if strict: + # Find the first 2D bucket that accepts this example dim0_begin = bisect_left(buckets[:, 0], example_lens[0]) if dim0_begin == buckets.shape[0]: return None + # Find the last 2D bucket that accepts this example dim0_end = dim0_begin while dim0_end < buckets.shape[0] and buckets[dim0_end, 0] == buckets[dim0_begin, 0]: dim0_end += 1 + # Find the smallest 2D bucket in this range that accepts this example dim1_begin = bisect_left(buckets[dim0_begin:dim0_end, 1], example_lens[1]) if dim1_begin == dim0_end - dim0_begin: return None - return dim0_begin + dim1_begin + fit_idx = dim0_begin + dim1_begin + # Apply max_ratio (token-per-second/token-per-token) filtering if requested + if max_ratio is not None and example_lens[1] / example_lens[0] > max_ratio[fit_idx]: + return None + return fit_idx # 2D bucketing 'lenient' mode - linear search (as 2nd dim may not be growing monotonically). # E.g. for buckets = [(10, 5), (10, 10), (20, 12), (20, 18)] @@ -294,6 +316,8 @@ class TokenPerSecondFilter: def __init__(self, tps_min: float | None, tps_max: float | None) -> None: self.tps_min = ifnone(tps_min, -1) + if isinstance(tps_max, Sequence): + tps_max = float("inf") # filtering handled in bucketing filter self.tps_max = ifnone(tps_max, float("inf")) assert tps_min <= tps_max, f"{tps_min=} {tps_max=}" self.enabled = tps_min > 0 or tps_max < float("inf") @@ -314,6 +338,8 @@ class TokenPerTokenFilter: def __init__(self, tpt_min: float | None, tpt_max: float | None) -> None: self.tpt_min = ifnone(tpt_min, -1) + if isinstance(tpt_max, Sequence): + tpt_max = float("inf") # filtering handled in bucketing filter self.tpt_max = ifnone(tpt_max, float("inf")) assert tpt_min <= tpt_max, f"{tpt_min=} {tpt_max=}" self.enabled = tpt_min > 0 or tpt_max < float("inf") diff --git a/tests/collections/common/test_2d_bucketing_constraint.py b/tests/collections/common/test_2d_bucketing_constraint.py index 21bda0ce8c3f..285df28d4ab8 100644 --- a/tests/collections/common/test_2d_bucketing_constraint.py +++ b/tests/collections/common/test_2d_bucketing_constraint.py @@ -214,6 +214,26 @@ def test_2d_bucketing_filter_strict(duration, num_tokens, should_keep, bucket_id assert constraint.select_bucket(constraint.max_seq_len_buckets, cut) == bucket_idx +def test_2d_bucketing_filter_strict_max_ratio(): + buckets = [(5.0, 10), (5.0, 20), (10.0, 15), (10.0, 30)] + max_ratio = [4.0, 4.0, 3.0, 3.0] + batch_sizes = [4, 3, 2, 1] + + # Without max_ratio it works because both dims fit bucket at idx 1 + constraint = FixedBucketBatchSizeConstraint2D(buckets, batch_sizes, strict_2d=True) + filter_2d = BucketingFilter(constraint) + cut = make_cut(duration=2.0, num_tokens=20) + assert filter_2d(cut) == True + assert constraint.select_bucket(constraint.max_seq_len_buckets, cut) == 1 + + # With max_ratio it's filtered out because 20 / 2.0 = 10.0 but max_ratio is 4.0 + constraint = FixedBucketBatchSizeConstraint2D(buckets, batch_sizes, strict_2d=True, max_ratio=max_ratio) + filter_2d = BucketingFilter(constraint) + cut = make_cut(duration=2.0, num_tokens=20) + assert filter_2d(cut) == False + assert constraint.select_bucket(constraint.max_seq_len_buckets, cut) == None + + class _Identity(torch.utils.data.Dataset): def __getitem__(self, item): return item From c9f0b7194484ab340515a50791a2632bf9185165 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 6 Jan 2025 10:50:15 -0800 Subject: [PATCH 07/18] fix max_tps/max_tpt config parsing for list values MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- nemo/collections/common/data/lhotse/dataloader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index fab106981f9e..2a5e56a950e9 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -119,7 +119,7 @@ class LhotseDataLoadingConfig: min_duration: float | None = -1 max_duration: float | None = float("inf") min_tps: int = -1 # allowed tokens per second (audio-only) - max_tps: float = float("inf") + max_tps: Any = float("inf") # float | list[float] # * Text input min_tokens: int | None = None max_tokens: int | None = None @@ -127,7 +127,7 @@ class LhotseDataLoadingConfig: # For 2D bucketing it's always false, as we report a tuple of (context_len, answer_len). measure_total_length: bool = True min_tpt: int = -1 # allowed tokens per token (text-only) - max_tpt: float = float("inf") + max_tpt: Any = float("inf") # float | list[float] # 3. Supported existing NeMo options. shuffle: bool = False From d8b2d750c8ead2680a6a683020a20dd4d9efb526 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 9 Jan 2025 12:59:44 -0500 Subject: [PATCH 08/18] Transition: tarred_random_access -> skip_missing_manifest_entries MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- nemo/collections/common/data/lhotse/cutset.py | 8 ++-- .../common/data/lhotse/dataloader.py | 20 ++++----- .../common/data/lhotse/nemo_adapters.py | 42 +++++++++---------- .../estimate_duration_bins_2d.py | 3 +- 4 files changed, 35 insertions(+), 38 deletions(-) diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 63e93d8cf860..b2c74c16065a 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -190,7 +190,7 @@ def read_dataset_config(config) -> tuple[CutSet, bool]: "force_finite": config.get("force_finite", False), "max_open_streams": config.get("max_open_streams", None), "token_equivalent_duration": config.get("token_equivalent_duration", None), - "tarred_random_access": config.get("tarred_random_access", False), + "skip_missing_manifest_entries": config.get("skip_missing_manifest_entries", False), } input_cfg = config.input_cfg if isinstance(input_cfg, (str, Path)): @@ -510,11 +510,11 @@ def read_nemo_manifest(config) -> tuple[CutSet, bool]: LazyNeMoTarredIterator( config.manifest_filepath, tar_paths=config.tarred_audio_filepaths, - tarred_random_access=config.tarred_random_access, + skip_missing_manifest_entries=config.skip_missing_manifest_entries, **common_kwargs, ) ) - if not config.tarred_random_access and not force_finite: + if not force_finite: cuts = cuts.repeat() else: cuts = CutSet(LazyNeMoIterator(config.manifest_filepath, **notar_kwargs, **common_kwargs)) @@ -552,7 +552,7 @@ def read_nemo_manifest(config) -> tuple[CutSet, bool]: nemo_iter = LazyNeMoTarredIterator( manifest_path=manifest_path, tar_paths=tar_path, - tarred_random_access=config.tarred_random_access, + skip_missing_manifest_entries=config.skip_missing_manifest_entries, **common_kwargs, ) else: diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index 2a5e56a950e9..cffc708af871 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -77,7 +77,8 @@ class LhotseDataLoadingConfig: cuts_path: str | None = None shar_path: Any = None # str | list[str | tuple[str, float | int]] | None = None # Enable this to support dataloading from JSON manifests that reference subsets of audio tar files. - tarred_random_access: bool = False + skip_missing_manifest_entries: bool = False + tarred_random_access: bool = False # deprecated, replaced by: skip_missing_manifest_entries # 2. Batch size. # a. Existing NeMo options. batch_size: int | None = None @@ -708,13 +709,20 @@ def make_structured_with_schema_warnings(config: DictConfig | dict) -> DictConfi if not isinstance(config, DictConfig): config = DictConfig(config) + if config.get("tarred_random_access", False): + warnings.warn( + "Option 'tarred_random_access' is deprecated and replaced with 'skip_missing_manifest_entries'.", + category=DeprecationWarning, + ) + config.skip_missing_manifest_entries = True + # Remove unsupported keys and warn about them. supported_keys = set(OmegaConf.to_container(default).keys()) received_keys = set(OmegaConf.to_container(config).keys()) unsupported_keys = received_keys - supported_keys if unsupported_keys: warnings.warn( - f"The following configuration keys are no longer supported " f"and ignored: {','.join(unsupported_keys)}", + f"The following configuration keys are no longer supported and ignored: {','.join(unsupported_keys)}", category=DeprecationWarning, ) config = OmegaConf.masked_copy(config, list(supported_keys)) @@ -722,14 +730,6 @@ def make_structured_with_schema_warnings(config: DictConfig | dict) -> DictConfi return OmegaConf.merge(default, config) -def determine_use_iterable_dataset(use_iterable_dataset: bool, config: DictConfig) -> bool: - assert not ( - config.force_map_dataset and config.force_iterable_dataset - ), "Conflicting options: force_map_dataset=True and force_iterable_dataset=True" - use_iterable_dataset = (use_iterable_dataset or config.force_iterable_dataset) and not config.force_map_dataset - return use_iterable_dataset - - def tokenize(example, tokenizer): if isinstance(example, Cut): for s in example.supervisions: diff --git a/nemo/collections/common/data/lhotse/nemo_adapters.py b/nemo/collections/common/data/lhotse/nemo_adapters.py index a34a2c074a11..ce05c177154e 100644 --- a/nemo/collections/common/data/lhotse/nemo_adapters.py +++ b/nemo/collections/common/data/lhotse/nemo_adapters.py @@ -223,6 +223,11 @@ class LazyNeMoTarredIterator: This can be used for other cloud storage APIs such as S3, GCS, etc. The same mechanism applies to ``manifest_path``. + If your data has been filtered so that the JSON manifests refer to just a subset of recordings, + set ``skip_missing_manifest_entries` to ``True``. + This will still read the tar files sequentially (very fast) and discard the audio files that + are not present in the corresponding manifest. + The ``shard_seed`` argument is used to seed the RNG shuffling the shards. By default, it's ``trng`` which samples a seed number from OS-provided TRNG (see Python ``secrets`` module). Seed is resolved lazily so that every dataloading worker may sample a different one. @@ -264,10 +269,10 @@ def __init__( shard_seed: int | Literal["trng", "randomized"] = "trng", text_field: str = "text", lang_field: str = "lang", - tarred_random_access: bool = False, + skip_missing_manifest_entries: bool = False, extra_fields: list[dict[str, str]] | None = None, ) -> None: - self.tarred_random_access = tarred_random_access + self.skip_missing_manifest_entries = skip_missing_manifest_entries self.shard_id_to_manifest: dict[int, Iterable[dict]] self.paths = expand_sharded_filepaths(manifest_path) if len(self.paths) == 1: @@ -346,29 +351,21 @@ def _validate(self) -> None: def shard_ids(self) -> List[int]: return sorted(self.shard_id_to_manifest.keys()) - def _iter_random_read(self, tar_path, shard_manifest, manifest_path) -> Generator[tuple[dict, bytes], None, None]: - with tarfile.open(fileobj=BytesIO(open_best(tar_path, mode="rb").read()), mode="r") as tar: - for data in shard_manifest: + def _iter_sequential(self, tar_path, shard_manifest, manifest_path) -> Generator[tuple[dict, bytes], None, None]: + with tarfile.open(fileobj=open_best(tar_path, mode="rb"), mode="r|*") as tar: + for tar_info in tar: try: - tar_info = tar.getmember(data) + data = shard_manifest[tar_info.name] raw_audio = tar.extractfile(tar_info).read() yield data, raw_audio, tar_info except KeyError as e: - raise RuntimeError( - f"Mismatched entry between JSON manifest ('{manifest_path}') and tar file ('{tar_path}'). " - f"The following audio_filepath='{data['audio_filepath']}' was not found in the tar file." - ) from e - - def _iter_sequential(self, tar_path, shard_manifest, manifest_path) -> Generator[tuple[dict, bytes], None, None]: - with tarfile.open(fileobj=open_best(tar_path, mode="rb"), mode="r|*") as tar: - for tar_info in tar: - assert tar_info.name in shard_manifest, ( - f"Mismatched entry between JSON manifest ('{manifest_path}') and tar file ('{tar_path}'). " - f"Cannot locate JSON entry for tar file '{tar_info.name}'" - ) - data = shard_manifest[tar_info.name] - raw_audio = tar.extractfile(tar_info).read() - yield data, raw_audio, tar_info + if self.skip_missing_manifest_entries: + continue + else: + raise RuntimeError( + f"Mismatched entry between JSON manifest ('{manifest_path}') and tar file ('{tar_path}'). " + f"Cannot locate JSON entry for tar file '{tar_info.name}'" + ) from e def __iter__(self) -> Generator[Cut, None, None]: shard_ids = self.shard_ids @@ -384,7 +381,6 @@ def __iter__(self) -> Generator[Cut, None, None]: # They have multiple JSONL entries where audio paths end with '-sub1', '-sub2', etc. for each offset. offset_pattern = re.compile(r'^(?P.+)(?P-sub\d+)(?P\.\w+)?$') - iter_fn = self._iter_random_read if self.tarred_random_access else self._iter_sequential for sid in shard_ids: manifest_path = self.paths[sid] if len(self.paths) > 1 else self.paths[0] @@ -398,7 +394,7 @@ def basename(d: dict) -> str: shard_manifest: dict[str, list[dict]] = groupby(basename, self.shard_id_to_manifest[sid]) tar_path = self.shard_id_to_tar_path[sid] try: - for data, raw_audio, tar_info in iter_fn(tar_path, shard_manifest, manifest_path): + for data, raw_audio, tar_info in self._iter_sequential(tar_path, shard_manifest, manifest_path): meta = soundfile.info(BytesIO(raw_audio)) recording = Recording( id=tar_info.path, diff --git a/scripts/speech_recognition/estimate_duration_bins_2d.py b/scripts/speech_recognition/estimate_duration_bins_2d.py index 1c57eb2c583d..14e1732a59cf 100644 --- a/scripts/speech_recognition/estimate_duration_bins_2d.py +++ b/scripts/speech_recognition/estimate_duration_bins_2d.py @@ -355,10 +355,11 @@ def main(): tps_thresholds = "[" + ",".join(f"{t:.2f}" for t in tps_thresholds) + "]" if not args.quiet: duration_filter.print_report() - print("Use the following options in your config:") + print("Use the following options in your config:") print(f"\tuse_bucketing=1") print(f"\tnum_buckets={args.buckets}") print(f"\tbucket_duration_bins={duration_bins}") + print(f"The max_tps setting below is optional, use it if your data has low quality long transcript outliers:") print(f"\tmax_tps={tps_thresholds}") From d8f4dc7c2f901e5de55ac6097dcb1116fa868540 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 9 Jan 2025 13:23:28 -0500 Subject: [PATCH 09/18] Improve warning messages MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- nemo/collections/common/data/lhotse/dataloader.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index cffc708af871..c540965682e6 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -710,20 +710,25 @@ def make_structured_with_schema_warnings(config: DictConfig | dict) -> DictConfi config = DictConfig(config) if config.get("tarred_random_access", False): - warnings.warn( + logging.warning( "Option 'tarred_random_access' is deprecated and replaced with 'skip_missing_manifest_entries'.", - category=DeprecationWarning, ) config.skip_missing_manifest_entries = True + if config.skip_missing_manifest_entries: + logging.warning( + "Note: skip_missing_manifest_entries is set to True. " + "If any of your manifests and tar files are mismatched, the entire tar file will be skipped without warning. " + "It's your responsibility to ensure data integrity with this setting." + ) + # Remove unsupported keys and warn about them. supported_keys = set(OmegaConf.to_container(default).keys()) received_keys = set(OmegaConf.to_container(config).keys()) unsupported_keys = received_keys - supported_keys if unsupported_keys: - warnings.warn( - f"The following configuration keys are no longer supported and ignored: {','.join(unsupported_keys)}", - category=DeprecationWarning, + logging.warning( + f"The following configuration keys are ignored by Lhotse dataloader: {','.join(unsupported_keys)}", ) config = OmegaConf.masked_copy(config, list(supported_keys)) From bff9690c2a58192e6e167f978b040ed06d4b2e94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 10 Jan 2025 08:03:34 -0800 Subject: [PATCH 10/18] Support single unified BPE tokenizer for Canary2 Signed-off-by: Piotr Zelasko --- nemo/collections/common/prompts/canary2.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/nemo/collections/common/prompts/canary2.py b/nemo/collections/common/prompts/canary2.py index 3aed7a3bfa10..2aa657d294cc 100644 --- a/nemo/collections/common/prompts/canary2.py +++ b/nemo/collections/common/prompts/canary2.py @@ -26,6 +26,7 @@ CANARY_BOS, CANARY_EOS, CANARY_SPECIAL_TOKENIZER, + CanaryTokenizer, ) @@ -196,8 +197,13 @@ def canary2(cut: Cut, prompt: Canary2PromptFormatter) -> dict[str, torch.Tensor] ), ) ans = prompt.encode_dialog(turns) + if isinstance(prompt.tokenizer, CanaryTokenizer): + eos = prompt.tokenizer.eos + else: # SPE + eos = prompt.tokenizer.token_to_id(CANARY_EOS) + assert eos > -1, "Invalid tokenizer: tokenizer.token_to_id('{CANARY_EOS}') returned {eos}" assert ( - ans["answer_ids"][-1].item() == prompt.tokenizer.eos + ans["answer_ids"][-1].item() == eos ), f"Expected the last token in answer_ids to be EOS, but we got {ans['answer_ids']}" ans["answer_ids"] = ans["answer_ids"][:-1] # Strip Canary's EOS return ans From cab0f8bdc7b31b182ad257bbac8d177f2fa29ead Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 10 Jan 2025 10:20:30 -0800 Subject: [PATCH 11/18] bugfix +es to unified tokenizer support for Canary MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../asr/data/audio_to_text_lhotse_prompted.py | 6 ++++- .../asr/models/aed_multitask_models.py | 12 +++++++-- .../common/data/lhotse/dataloader.py | 26 ++++++++++--------- 3 files changed, 29 insertions(+), 15 deletions(-) diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py index f40dffb79467..6ee7bec09ab4 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py @@ -71,8 +71,12 @@ def __init__( super().__init__() self.tokenizer = tokenizer self.load_audio = AudioSamples(fault_tolerant=True) - self.padding_value = self.tokenizer.pad self.prompt = prompt + pad_id = self.tokenizer.pad_id + if pad_id == -1: + pad_id = self.tokenizer.token_to_id("") + assert pad_id > -1, "Invalid tokenizer: both tokenizer.pad_id and tokenizer.token_to_id('') returned -1." + self.padding_value = pad_id def __getitem__(self, cuts: CutSet) -> PromptedAudioToTextMiniBatch: audio, audio_lens, cuts = self.load_audio(cuts) diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py index a609eeaccf9e..e02ebb7e9d87 100644 --- a/nemo/collections/asr/models/aed_multitask_models.py +++ b/nemo/collections/asr/models/aed_multitask_models.py @@ -211,8 +211,12 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): ) # Define autoregressive CE loss + pad_id = self.tokenizer.pad_id + if pad_id == -1: + pad_id = self.tokenizer.token_to_id("") + assert pad_id > -1, "Invalid tokenizer: both tokenizer.pad_id and tokenizer.token_to_id('') returned -1." with open_dict(self.cfg.loss): - self.cfg.loss.pad_id = self.tokenizer.pad_id + self.cfg.loss.pad_id = pad_id self.loss = EncDecMultiTaskModel.from_config_dict(self.cfg.loss) @@ -387,8 +391,12 @@ def change_vocabulary( self.cfg.decoding = decoding_cfg # Setup loss + pad_id = self.tokenizer.pad_id + if pad_id == -1: + pad_id = self.tokenizer.token_to_id("") + assert pad_id > -1, "Invalid tokenizer: both tokenizer.pad_id and tokenizer.token_to_id('') returned -1." with open_dict(self.cfg.loss): - self.cfg.loss.pad_id = self.tokenizer.pad_id + self.cfg.loss.pad_id = pad_id del self.loss self.loss = EncDecMultiTaskModel.from_config_dict(self.cfg.loss) diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index c540965682e6..b17fa51a2660 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -709,12 +709,24 @@ def make_structured_with_schema_warnings(config: DictConfig | dict) -> DictConfi if not isinstance(config, DictConfig): config = DictConfig(config) + # Remove unsupported keys and warn about them. + supported_keys = set(OmegaConf.to_container(default).keys()) + received_keys = set(OmegaConf.to_container(config).keys()) + unsupported_keys = received_keys - supported_keys + unsupported_keys.discard("use_lhotse") + if unsupported_keys: + logging.warning( + f"The following configuration keys are ignored by Lhotse dataloader: {','.join(unsupported_keys)}", + ) + config = OmegaConf.masked_copy(config, list(supported_keys)) + + config = OmegaConf.merge(default, config) + if config.get("tarred_random_access", False): logging.warning( "Option 'tarred_random_access' is deprecated and replaced with 'skip_missing_manifest_entries'.", ) config.skip_missing_manifest_entries = True - if config.skip_missing_manifest_entries: logging.warning( "Note: skip_missing_manifest_entries is set to True. " @@ -722,17 +734,7 @@ def make_structured_with_schema_warnings(config: DictConfig | dict) -> DictConfi "It's your responsibility to ensure data integrity with this setting." ) - # Remove unsupported keys and warn about them. - supported_keys = set(OmegaConf.to_container(default).keys()) - received_keys = set(OmegaConf.to_container(config).keys()) - unsupported_keys = received_keys - supported_keys - if unsupported_keys: - logging.warning( - f"The following configuration keys are ignored by Lhotse dataloader: {','.join(unsupported_keys)}", - ) - config = OmegaConf.masked_copy(config, list(supported_keys)) - - return OmegaConf.merge(default, config) + return config def tokenize(example, tokenizer): From acf5ea4fb2316e80a38a59fa54729201a0eab8e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Sun, 12 Jan 2025 17:22:00 -0500 Subject: [PATCH 12/18] CanaryBPETokenizer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- nemo/collections/asr/parts/mixins/mixins.py | 8 +++++-- .../common/tokenizers/canary_tokenizer.py | 23 +++++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/nemo/collections/asr/parts/mixins/mixins.py b/nemo/collections/asr/parts/mixins/mixins.py index 25ade32fffd8..47fbddae5edc 100644 --- a/nemo/collections/asr/parts/mixins/mixins.py +++ b/nemo/collections/asr/parts/mixins/mixins.py @@ -110,8 +110,12 @@ def _setup_monolingual_tokenizer(self, tokenizer_cfg: DictConfig): if special_tokens is not None: raise ValueError("`special_tokens` are no longer supported for SentencePiece based tokenizers.") - # Update special tokens - self.tokenizer = tokenizers.SentencePieceTokenizer(model_path=model_path) + if "custom_tokenizer" in self.tokenizer_cfg: + self.tokenizer = self.from_config_dict( + {"_target_": tokenizer_cfg["custom_tokenizer"]["_target_"], "model_path": model_path} + ) + else: + self.tokenizer = tokenizers.SentencePieceTokenizer(model_path=model_path) if 'vocab_path' in self.tokenizer_cfg: vocab_path = self.tokenizer_cfg.get('vocab_path') diff --git a/nemo/collections/common/tokenizers/canary_tokenizer.py b/nemo/collections/common/tokenizers/canary_tokenizer.py index 04dc6e3a68a9..c0972e5c8c63 100644 --- a/nemo/collections/common/tokenizers/canary_tokenizer.py +++ b/nemo/collections/common/tokenizers/canary_tokenizer.py @@ -191,6 +191,29 @@ def build_special_tokenizer( return spl_tokenizer +class CanaryBPETokenizer(SentencePieceTokenizer): + """ + Thin wrapper around SPE tokenizer that overwrites SPE's BOS/EOS/PAD with Canary's special tokens + for compatibility with CanaryTokenizer (aggregate). + """ + + @cached_property + def eos_id(self) -> int: + return self.token_to_id(CANARY_EOS) + + @cached_property + def bos_id(self) -> int: + return self.token_to_id(CANARY_BOS) + + @cached_property + def nospeech_id(self) -> int: + return self.token_to_id(CANARY_NOSPEECH) + + @cached_property + def pad_id(self) -> int: + return self.token_to_id(CANARY_PAD) + + def _map_canary1_to_canary2_lang(lang: str, available_langs: list[str]) -> str: if len(lang) != 2 or lang in available_langs: return lang From 9fced0e6d7af901b7eecb8a4a9466674708aee80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Sun, 12 Jan 2025 14:28:28 -0800 Subject: [PATCH 13/18] Revert changes to fetching pad id in AED models MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../asr/data/audio_to_text_lhotse_prompted.py | 6 +----- nemo/collections/asr/models/aed_multitask_models.py | 12 ++---------- 2 files changed, 3 insertions(+), 15 deletions(-) diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py index 6ee7bec09ab4..1bb1410555d6 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py @@ -71,12 +71,8 @@ def __init__( super().__init__() self.tokenizer = tokenizer self.load_audio = AudioSamples(fault_tolerant=True) + self.padding_value = self.tokenizer.pad_id self.prompt = prompt - pad_id = self.tokenizer.pad_id - if pad_id == -1: - pad_id = self.tokenizer.token_to_id("") - assert pad_id > -1, "Invalid tokenizer: both tokenizer.pad_id and tokenizer.token_to_id('') returned -1." - self.padding_value = pad_id def __getitem__(self, cuts: CutSet) -> PromptedAudioToTextMiniBatch: audio, audio_lens, cuts = self.load_audio(cuts) diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py index e02ebb7e9d87..a609eeaccf9e 100644 --- a/nemo/collections/asr/models/aed_multitask_models.py +++ b/nemo/collections/asr/models/aed_multitask_models.py @@ -211,12 +211,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): ) # Define autoregressive CE loss - pad_id = self.tokenizer.pad_id - if pad_id == -1: - pad_id = self.tokenizer.token_to_id("") - assert pad_id > -1, "Invalid tokenizer: both tokenizer.pad_id and tokenizer.token_to_id('') returned -1." with open_dict(self.cfg.loss): - self.cfg.loss.pad_id = pad_id + self.cfg.loss.pad_id = self.tokenizer.pad_id self.loss = EncDecMultiTaskModel.from_config_dict(self.cfg.loss) @@ -391,12 +387,8 @@ def change_vocabulary( self.cfg.decoding = decoding_cfg # Setup loss - pad_id = self.tokenizer.pad_id - if pad_id == -1: - pad_id = self.tokenizer.token_to_id("") - assert pad_id > -1, "Invalid tokenizer: both tokenizer.pad_id and tokenizer.token_to_id('') returned -1." with open_dict(self.cfg.loss): - self.cfg.loss.pad_id = pad_id + self.cfg.loss.pad_id = self.tokenizer.pad_id del self.loss self.loss = EncDecMultiTaskModel.from_config_dict(self.cfg.loss) From 4c1372851db5cd9c584ca30b0935086530102998 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 13 Jan 2025 17:31:33 -0800 Subject: [PATCH 14/18] Update docs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- docs/source/asr/datasets.rst | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/docs/source/asr/datasets.rst b/docs/source/asr/datasets.rst index 8298567ff7cc..586bedc03c32 100644 --- a/docs/source/asr/datasets.rst +++ b/docs/source/asr/datasets.rst @@ -1079,23 +1079,22 @@ To run 2D bucketing with 30 buckets sub-divided into 5 sub-buckets each (150 buc # The script's output: Use the following options in your config: + use_bucketing=1 num_buckets=30 bucket_duration_bins=[[1.91,10],[1.91,17],[1.91,25],... - max_duration=... - max_tps=... - + The max_tps setting below is optional, use it if your data has low quality long transcript outliers: + max_tps=[13.2,13.2,11.8,11.8,...] Note that the output in ``bucket_duration_bins`` is a nested list, where every bin specifies the maximum duration and the maximum number of tokens that go into the bucket. Passing this option to Lhotse dataloader will automatically enable 2D bucketing. -Note the presence of ``max_duration`` and ``max_tps`` (token-per-second) options: -these need to be included in dataloader's configuration to ensure we can use the buckets correctly at runtime -in case of outliers. -In general, if you change your data in training, it is highly advisable to re-estimate the duration bins. - -Note that reasonable values for tokens-per-second rarely exceed 12tps with reasonably good tokenizers. -If you find your dataset's TPS is much higher than that, you may have some bad data outliers. -In that case you may specify ``--max_tps`` option to discard those both in bin estimation and dataloading. + +Note the presence of ``max_tps`` (token-per-second) option. +It is optional to include it in the dataloader configuration: if you do, we will apply an extra filter +that discards examples which have more tokens per second than the threshold value. +The threshold is determined for each bucket separately based on data distribution, and can be controlled +with the option ``--token_outlier_threshold``. +This filtering is useful primarily for noisy datasets to discard low quality examples / outliers. We also support aggregate tokenizers for 2D bucketing estimation: From ebc720f27034855f9dbede3a6f056dc465ac5686 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 14 Jan 2025 15:11:23 -0800 Subject: [PATCH 15/18] fix tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- nemo/collections/common/data/lhotse/sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/common/data/lhotse/sampling.py b/nemo/collections/common/data/lhotse/sampling.py index 9de896a862f1..f5b1a2987754 100644 --- a/nemo/collections/common/data/lhotse/sampling.py +++ b/nemo/collections/common/data/lhotse/sampling.py @@ -165,7 +165,7 @@ def find_smallest_bucket( Returns a tuple of (smallest_bin, bin_idx), or (None, None) if no bucket fits the example. """ # 1D bucketing - binary search. - if isinstance(example_lens, float): # 1-D + if isinstance(example_lens, (float, int)): # 1-D idx = bisect_left(buckets, example_lens) if idx == len(buckets): return None From a238da22ba591e429f46b19eb40f9f4a835f83da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 16 Jan 2025 17:23:53 -0800 Subject: [PATCH 16/18] Remove max_tps and max_duration from OOMptimizer output MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- scripts/speech_recognition/oomptimizer.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/scripts/speech_recognition/oomptimizer.py b/scripts/speech_recognition/oomptimizer.py index 4b4deb092bea..d46179742ff8 100755 --- a/scripts/speech_recognition/oomptimizer.py +++ b/scripts/speech_recognition/oomptimizer.py @@ -520,8 +520,6 @@ def step(): if is_2d_bucketing: # 2D bucketing doesn't support bucket merging. final_profile = [["[" + ",".join(map(str, b)) + "]", bs] for (b, _, __), bs in profile.items()] - max_input_len, max_output_len = buckets[-1] - ratio = max_output_len / max_input_len else: click.echo("Bucket merging stage...") final_profile = [] @@ -534,7 +532,6 @@ def step(): final_profile[-1][0] = bucket continue final_profile.append([bucket, bs]) - max_input_len = final_profile[-1][0] click.secho(f"The profile was created with the following settings:") click.secho(f"* using {memory_fraction:.1%} of available GPU RAM.") @@ -543,9 +540,6 @@ def step(): click.secho("The final profile is:", bold=True) click.secho("\tbucket_duration_bins=[" + ",".join(str(seqlen) for seqlen, bs in final_profile) + "]", bold=True) click.secho("\tbucket_batch_size=[" + ",".join(str(bs) for seqlen, bs in final_profile) + "]", bold=True) - click.secho("\t(The following flags are suitable for ASR/speech-to-text models):") - click.secho(f"\tmax_tps={ratio}", bold=True) - click.secho(f"\tmax_duration={max_input_len}", bold=True) if __name__ == "__main__": From 0ab2c0943c1212023f84bbf0b777a00ba61f8e46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 21 Jan 2025 12:48:17 -0500 Subject: [PATCH 17/18] Fix integration test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- scripts/speech_recognition/estimate_duration_bins_2d.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/scripts/speech_recognition/estimate_duration_bins_2d.py b/scripts/speech_recognition/estimate_duration_bins_2d.py index 14e1732a59cf..5f4f0f0a1c11 100644 --- a/scripts/speech_recognition/estimate_duration_bins_2d.py +++ b/scripts/speech_recognition/estimate_duration_bins_2d.py @@ -314,7 +314,11 @@ def main(): tokenizer = None prompt = None if args.tokenizer is not None: - tokenizer = load_tokenizer(args.tokenizer, args.langs, 'canary' in args.prompt_format) + tokenizer = load_tokenizer( + paths=args.tokenizer, + langs=args.langs, + is_canary=args.prompt_format is not None and 'canary' in args.prompt_format, + ) if args.prompt_format is not None: prompt_defaults = None if args.prompt is not None: From e9e17e2833dab46fc762b1fec481a8af344b7995 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 22 Jan 2025 07:00:53 -0800 Subject: [PATCH 18/18] fix estimate_duration_bins_2d.py for BPE tokenizer without prompt MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- scripts/speech_recognition/estimate_duration_bins_2d.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/speech_recognition/estimate_duration_bins_2d.py b/scripts/speech_recognition/estimate_duration_bins_2d.py index 5f4f0f0a1c11..8f9b74cb6bd2 100644 --- a/scripts/speech_recognition/estimate_duration_bins_2d.py +++ b/scripts/speech_recognition/estimate_duration_bins_2d.py @@ -37,6 +37,7 @@ SentencePieceTokenizer, TokenizerSpec, ) +from nemo.collections.common.tokenizers.aggregate_tokenizer import TokenizerWrapper def parse_args(): @@ -275,7 +276,7 @@ def apply_tokenizer(cut, tokenizer=None, prompt: PromptFormatter = None): cut.supervisions[0].tokens = encoded["input_ids"] elif tokenizer is not None: - cut = tokenize(cut, tokenizer) + cut = tokenize(cut, TokenizerWrapper(tokenizer)) return cut