From 3f8f00837d48add7b4ea65629eb7bd4d3a60908d Mon Sep 17 00:00:00 2001
From: freewym
Date: Sun, 8 Nov 2020 03:03:32 -0500
Subject: [PATCH] misc fixes

---
 espresso/data/asr_k2_dataset.py             | 29 +++++++--
 espresso/tasks/speech_recognition_hybrid.py |  2 +-
 examples/mobvoihotwords/local/data_prep.py  | 71 ++++++++++++++-------
 examples/mobvoihotwords/path.sh             |  3 +-
 4 files changed, 72 insertions(+), 33 deletions(-)
 mode change 100644 => 100755 examples/mobvoihotwords/local/data_prep.py

diff --git a/espresso/data/asr_k2_dataset.py b/espresso/data/asr_k2_dataset.py
index 2c22b44dfc..6ac57e9a35 100644
--- a/espresso/data/asr_k2_dataset.py
+++ b/espresso/data/asr_k2_dataset.py
@@ -6,7 +6,7 @@
 import logging
 import os
 import re
-from typing import Dict, List
+from typing import Any, Dict, List, Optional
 
 import numpy as np
 
@@ -17,12 +17,23 @@
 import espresso.tools.utils as speech_utils
 
 try:
     # TODO use pip install once it's available
-    from espresso.tools.lhotse.cut import CutSet
+    from espresso.tools.lhotse.lhotse import CutSet
 except ImportError:
     raise ImportError("Please install Lhotse by `make lhotse` after entering espresso/tools")
 
 
-def collate(samples, pad_to_length=None, pad_to_multiple=1):
+def collate(
+    samples: List[Dict[str, Any]],
+    pad_to_length: Optional[Dict[str, int]] = None,
+    pad_to_multiple: int = 1,
+) -> Dict[str, Any]:
+    """Collate samples into a batch. We use :func:`speech_utils.collate_frames`
+    to collate and pad the input frames, and PyTorch's :func:`default_collate`
+    to collate and pad the targets/supervisions (following the example provided in Lhotse).
+    Samples in the batch are sorted in descending order of their input frame lengths.
+    It also allows one to specify the padded input length, and to further enforce
+    that length to be a multiple of `pad_to_multiple`.
+    """
     if len(samples) == 0:
         return {}
 
@@ -97,7 +108,7 @@ class AsrK2Dataset(FairseqDataset):
     A K2 Dataset for ASR.
 
     Args:
-        cuts (lhotse.CutSet): Lhotse CutSet to wrap
+        cuts (lhotse.CutSet): instance of Lhotse's CutSet to wrap
         shuffle (bool, optional): shuffle dataset elements before batching
             (default: True).
         pad_to_multiple (int, optional): pad src lengths to a multiple of this value
@@ -165,14 +176,18 @@ def __getitem__(self, index):
     def __len__(self):
         return len(self.cuts)
 
-    def collater(self, samples, pad_to_length=None):
+    def collater(
+        self,
+        samples: List[Dict[str, Any]],
+        pad_to_length: Optional[Dict[str, int]] = None,
+    ) -> Dict[str, Any]:
         """Merge a list of samples to form a mini-batch.
 
         Args:
            samples (List[dict]): samples to collate
            pad_to_length (dict, optional): a dictionary of {"source": source_pad_to_length}
-               to indicate the max length to pad to in source and target respectively.
+               to indicate the max length to pad to in source.
 
         Returns:
             dict: a mini-batch with the following keys:
@@ -188,7 +203,7 @@ def collater(self, samples, pad_to_length=None):
                 - `src_lengths` (IntTensor): 1D Tensor of the unpadded lengths
                   of each source sequence of shape `(bsz)`
 
-            - `target` (List[Dict[str, Any]]): an List representing a batch of
+            - `target` (List[Dict[str, Any]]): a List representing a batch of
               supervisions
         """
         return collate(
diff --git a/espresso/tasks/speech_recognition_hybrid.py b/espresso/tasks/speech_recognition_hybrid.py
index cf995774d1..5ee9d77f81 100644
--- a/espresso/tasks/speech_recognition_hybrid.py
+++ b/espresso/tasks/speech_recognition_hybrid.py
@@ -151,7 +151,7 @@ class SpeechRecognitionHybridConfig(FairseqDataclass):
 def get_k2_dataset_from_json(data_path, split, shuffle=True, pad_to_multiple=1, seed=1):
     try:
         # TODO use pip install once it's available
-        from espresso.tools.lhotse.cut import CutSet
+        from espresso.tools.lhotse.lhotse import CutSet
     except ImportError:
         raise ImportError("Please install Lhotse by `make lhotse` after entering espresso/tools")
 
diff --git a/examples/mobvoihotwords/local/data_prep.py b/examples/mobvoihotwords/local/data_prep.py
old mode 100644
new mode 100755
index 821c228a8e..87ba9b2737
--- a/examples/mobvoihotwords/local/data_prep.py
+++ b/examples/mobvoihotwords/local/data_prep.py
@@ -9,6 +9,7 @@
 import os
 import sys
 from typing import List
+from collections import defaultdict
 from concurrent.futures import ProcessPoolExecutor
 from pathlib import Path
 
@@ -18,11 +19,11 @@
 
 try:
     # TODO use pip install once it's available
-    from espresso.tools.lhotse import (
-        CutSet, Mfcc, MfccConfig, LilcomFilesWriter, SupervisionSet, WavAugmenter
+    from espresso.tools.lhotse.lhotse import (
+        CutSet, Mfcc, MfccConfig, LilcomFilesWriter, RecordingSet, SupervisionSet, WavAugmenter
     )
-    from espresso.tools.lhotse.manipulation import combine
-    from espresso.tools.lhotse.recipes.mobvoihotwords import download_and_untar, prepare_mobvoihotwords
+    from espresso.tools.lhotse.lhotse.manipulation import combine
+    from espresso.tools.lhotse.lhotse.recipes.mobvoihotwords import download_and_untar, prepare_mobvoihotwords
 except ImportError:
     raise ImportError("Please install Lhotse by `make lhotse` after entering espresso/tools")
 
@@ -33,7 +34,7 @@
     level=os.environ.get("LOGLEVEL", "INFO").upper(),
     stream=sys.stdout,
 )
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("mobvoihotwords.data_prep")
 
 
 def get_parser():
@@ -64,11 +65,25 @@ def main(args):
     corpus_dir = root_dir / "MobvoiHotwords"
     output_dir = root_dir
 
-    # Download and extract the corpus
+    logger.info("Download and extract the corpus")
     download_and_untar(root_dir)
 
-    # Prepare manifests
-    mobvoihotwords_manifests = prepare_mobvoihotwords(corpus_dir, output_dir)
+    logger.info("Prepare the manifests")
+    partitions = ["train", "dev", "test"]
+    if all(
+        (output_dir / f"{key}_{part}.json").is_file()
+        for key in ["recordings", "supervisions"] for part in partitions
+    ):
+        logger.info(f"All the manifest files are found in {output_dir}. Load from them directly")
+        mobvoihotwords_manifests = defaultdict(dict)
+        for part in partitions:
+            mobvoihotwords_manifests[part] = {
+                "recordings": RecordingSet.from_json(output_dir / f"recordings_{part}.json"),
+                "supervisions": SupervisionSet.from_json(output_dir / f"supervisions_{part}.json")
+            }
+    else:
+        logger.info("It may take a long time")
+        mobvoihotwords_manifests = prepare_mobvoihotwords(corpus_dir, output_dir)
     logger.info(
         "train/dev/test size: {}/{}/{}".format(
             len(mobvoihotwords_manifests["train"]["recordings"]),
@@ -81,16 +96,17 @@
     np.random.seed(args.seed)
     # equivalent to Kaldi's mfcc_hires config
     mfcc = Mfcc(config=MfccConfig(num_mel_bins=40, num_ceps=40, low_freq=20, high_freq=-400))
-    for partition, manifests in mobvoihotwords_manifests.items():
+    for part, manifests in mobvoihotwords_manifests.items():
         cut_set = CutSet.from_manifests(
             recordings=manifests["recordings"],
             supervisions=manifests["supervisions"],
         )
         sampling_rate = next(iter(cut_set)).sampling_rate
         with ProcessPoolExecutor(args.num_jobs) as ex:
-            if "train" in partition:
+            if part == "train":
                 # split negative recordings into smaller chunks with lengths sampled from
                 # length distribution of positive recordings
+                logger.info(f"Split negative recordings in '{part}' set")
                 pos_durs = get_positive_durations(manifests["supervisions"])
                 with numpy_seed(args.seed):
                     cut_set = keep_positives_and_split_negatives(
@@ -100,7 +116,8 @@
                         overlap_duration=args.overlap_duration,
                     )
                 # "clean" set
-                with LilcomFilesWriter(f"{output_dir}/feats_{partition}_clean") as storage:
+                logger.info(f"Extract features for '{part}' set")
+                with LilcomFilesWriter(f"{output_dir}/feats_{part}_clean") as storage:
                     cut_set_clean = cut_set.compute_and_store_features(
                         extractor=mfcc,
                         storage=storage,
@@ -108,56 +125,62 @@
                         executor=ex,
                     )
                 # augmented with reverberation
-                with LilcomFilesWriter(f"{output_dir}/feats_{partition}_rev") as storage:
+                logger.info(f"Extract features from '{part}' set with reverberation")
+                with LilcomFilesWriter(f"{output_dir}/feats_{part}_rev") as storage:
                     with numpy_seed(args.seed):
                         cut_set_rev = cut_set.compute_and_store_features(
                             extractor=mfcc,
                             storage=storage,
-                            augmenter=WavAugmenter(effect_chain=reverb()),
-                            excutor=ex,
+                            augmenter=WavAugmenter(effect_chain=reverb(), sampling_rate=sampling_rate),
+                            executor=ex,
                         )
                 cut_set_rev = CutSet.from_cuts(
                     cut.with_id("rev-" + cut.id) for cut in cut_set_rev.cuts
                 )
                 # augmented with speed perturbation
-                with LilcomFilesWriter(f"{output_dir}/feats_{partition}_sp1.1") as storage:
+                logger.info(f"Extract features from '{part}' set with speed perturbation")
+                with LilcomFilesWriter(f"{output_dir}/feats_{part}_sp1.1") as storage:
                     cut_set_sp1p1 = cut_set.compute_and_store_features(
                         extractor=mfcc,
                         storage=storage,
                         augmenter=WavAugmenter(
-                            effect_chain=speed(sampling_rate=sampling_rate, factor=1.1)
+                            effect_chain=speed(sampling_rate=sampling_rate, factor=1.1),
+                            sampling_rate=sampling_rate,
                         ),
-                        excutor=ex,
+                        executor=ex,
                     )
                 cut_set_sp1p1 = CutSet.from_cuts(
                     cut.with_id("sp1.1-" + cut.id) for cut in cut_set_sp1p1.cuts
                 )
-                with LilcomFilesWriter(f"{output_dir}/feats_{partition}_sp0.9") as storage:
+                with LilcomFilesWriter(f"{output_dir}/feats_{part}_sp0.9") as storage:
                     cut_set_sp0p9 = cut_set.compute_and_store_features(
                         extractor=mfcc,
                         storage=storage,
                         augmenter=WavAugmenter(
-                            effect_chain=speed(sampling_rate=sampling_rate, factor=0.9)
+                            effect_chain=speed(sampling_rate=sampling_rate, factor=0.9),
+                            sampling_rate=sampling_rate,
                         ),
-                        excutor=ex,
+                        executor=ex,
                     )
                 cut_set_sp0p9 = CutSet.from_cuts(
                     cut.with_id("sp0.9-" + cut.id) for cut in cut_set_sp0p9.cuts
                 )
                 # combine the clean and augmented sets together
+                logger.info("Combine all the features above")
                 cut_set = combine(
                     cut_set_clean, cut_set_rev, cut_set_sp1p1, cut_set_sp0p9
                 )
             else:
                 # no augmentations for dev and test sets
-                with LilcomFilesWriter(f"{output_dir}/feats_{partition}") as storage:
+                logger.info(f"Extract features for '{part}' set")
+                with LilcomFilesWriter(f"{output_dir}/feats_{part}") as storage:
                     cut_set = cut_set.compute_and_store_features(
                         extractor=mfcc,
                         storage=storage,
                         augmenter=None,
                         executor=ex,
                     )
-        mobvoihotwords_manifests[partition]["cuts"] = cut_set
-        cut_set.to_json(output_dir / f"cuts_{partition}.json.gz")
+        mobvoihotwords_manifests[part]["cuts"] = cut_set
+        cut_set.to_json(output_dir / f"cuts_{part}.json.gz")
 
 
 def get_positive_durations(sup_set: SupervisionSet) -> List[float]:
@@ -166,7 +189,7 @@ def get_positive_durations(sup_set: SupervisionSet) -> List[float]:
     "FREETEXT" for all negative recordings, and SupervisionSegment.duration
     equals to the corresponding Recording.duration.
     """
-    return [sup.dur for sup in sup_set.filter(lambda seg: seg.text != "FREETEXT")]
+    return [sup.duration for sup in sup_set.filter(lambda seg: seg.text != "FREETEXT")]
 
 
 def keep_positives_and_split_negatives(
diff --git a/examples/mobvoihotwords/path.sh b/examples/mobvoihotwords/path.sh
index a2576bef6f..180398bce8 100644
--- a/examples/mobvoihotwords/path.sh
+++ b/examples/mobvoihotwords/path.sh
@@ -9,8 +9,9 @@ export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/tools/sctk/bin
 export LC_ALL=C
 # END
 
-export PATH=~/anaconda3/bin:$PATH
+export PATH=/export/b03/ywang/anaconda3/bin:$PATH
 export PATH=$MAIN_ROOT:$MAIN_ROOT/espresso:$MAIN_ROOT/espresso/tools:$PATH
 export LD_LIBRARY_PATH=$MAIN_ROOT/espresso/tools/openfst/lib:$LD_LIBRARY_PATH
+export LD_LIBRARY_PATH=$MAIN_ROOT/espresso/tools/lhotse/tools/deps/sox-code/src/.libs:$LD_LIBRARY_PATH
 export PYTHONPATH=$MAIN_ROOT:$MAIN_ROOT/espresso:$MAIN_ROOT/espresso/tools:$MAIN_ROOT/espresso/tools/lhotse:$MAIN_ROOT/espresso/tools/pychain:$PYTHONPATH
 export PYTHONUNBUFFERED=1
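
The sketch below is not part of the patch and does not use the espresso or Lhotse APIs; it is a minimal, self-contained illustration in plain PyTorch (the helper name pad_frames and the toy shapes are made up) of the padding behavior described by the new collate() docstring above: inputs are sorted in descending order of frame length, and the padded length is rounded up to a multiple of pad_to_multiple.

import torch


def pad_frames(frames, pad_to_multiple=1):
    # Illustrative stand-in for the padding behavior described in collate()'s
    # docstring; not espresso's actual implementation.
    # Sort in descending order of frame length.
    frames = sorted(frames, key=lambda f: f.size(0), reverse=True)
    max_len = frames[0].size(0)
    # Round the padded length up to the next multiple of `pad_to_multiple`.
    if max_len % pad_to_multiple != 0:
        max_len += pad_to_multiple - max_len % pad_to_multiple
    padded = frames[0].new_zeros(len(frames), max_len, frames[0].size(1))
    for i, f in enumerate(frames):
        padded[i, : f.size(0)] = f
    lengths = torch.tensor([f.size(0) for f in frames], dtype=torch.int)
    return padded, lengths


# Three utterances of 37/50/23 frames with 40-dim features, padded to a multiple of 8 (i.e. 56 frames).
batch, lengths = pad_frames([torch.randn(n, 40) for n in (37, 50, 23)], pad_to_multiple=8)
print(batch.shape, lengths)  # torch.Size([3, 56, 40]) tensor([50, 37, 23], dtype=torch.int32)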