Skip to content

Commit

Permalink
misc fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
freewym committed Nov 8, 2020
1 parent 4fad398 commit 3f8f008
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 33 deletions.
29 changes: 22 additions & 7 deletions espresso/data/asr_k2_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import logging
import os
import re
from typing import Dict, List
from typing import Any, Dict, List, Optional

import numpy as np

Expand All @@ -17,12 +17,23 @@
import espresso.tools.utils as speech_utils
try:
# TODO use pip install once it's available
from espresso.tools.lhotse.cut import CutSet
from espresso.tools.lhotse.lhotse import CutSet
except ImportError:
raise ImportError("Please install Lhotse by `make lhotse` after entering espresso/tools")


def collate(samples, pad_to_length=None, pad_to_multiple=1):
def collate(
samples: List[Dict[str, Any]],
pad_to_length: Optional[Dict[str, int]] = None,
pad_to_multiple: int = 1,
) -> Dict[str, Any]:
"""Collate samples into a batch. We use :func:`speech_utils.collate_frames`
to collate and pad input frames, and PyTorch's :func:`default_collate`
to collate and pad target/supervisions (following the example provided in Lhotse).
Samples in the batch are in descending order of their input frame lengths.
It also allows to specify the padded input length and further enforce the length
to be a multiple of `pad_to_multiple`
"""
if len(samples) == 0:
return {}

Expand Down Expand Up @@ -97,7 +108,7 @@ class AsrK2Dataset(FairseqDataset):
A K2 Dataset for ASR.
Args:
cuts (lhotse.CutSet): Lhotse CutSet to wrap
cuts (lhotse.CutSet): instance of Lhotse's CutSet to wrap
shuffle (bool, optional): shuffle dataset elements before batching
(default: True).
pad_to_multiple (int, optional): pad src lengths to a multiple of this value
Expand Down Expand Up @@ -165,14 +176,18 @@ def __getitem__(self, index):
def __len__(self):
return len(self.cuts)

def collater(self, samples, pad_to_length=None):
def collater(
self,
samples: List[Dict[str, Any]],
pad_to_length: Optional[Dict[str, int]] = None,
) -> Dict[str, Any]:
"""Merge a list of samples to form a mini-batch.
Args:
samples (List[dict]): samples to collate
pad_to_length (dict, optional): a dictionary of
{"source": source_pad_to_length}
to indicate the max length to pad to in source and target respectively.
to indicate the max length to pad to in source.
Returns:
dict: a mini-batch with the following keys:
Expand All @@ -188,7 +203,7 @@ def collater(self, samples, pad_to_length=None):
- `src_lengths` (IntTensor): 1D Tensor of the unpadded
lengths of each source sequence of shape `(bsz)`
- `target` (List[Dict[str, Any]]): an List representing a batch of
- `target` (List[Dict[str, Any]]): a List representing a batch of
supervisions
"""
return collate(
Expand Down
2 changes: 1 addition & 1 deletion espresso/tasks/speech_recognition_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ class SpeechRecognitionHybridConfig(FairseqDataclass):
def get_k2_dataset_from_json(data_path, split, shuffle=True, pad_to_multiple=1, seed=1):
try:
# TODO use pip install once it's available
from espresso.tools.lhotse.cut import CutSet
from espresso.tools.lhotse.lhotse import CutSet
except ImportError:
raise ImportError("Please install Lhotse by `make lhotse` after entering espresso/tools")

Expand Down
71 changes: 47 additions & 24 deletions examples/mobvoihotwords/local/data_prep.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import os
import sys
from typing import List
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path

Expand All @@ -18,11 +19,11 @@

try:
# TODO use pip install once it's available
from espresso.tools.lhotse import (
CutSet, Mfcc, MfccConfig, LilcomFilesWriter, SupervisionSet, WavAugmenter
from espresso.tools.lhotse.lhotse import (
CutSet, Mfcc, MfccConfig, LilcomFilesWriter, RecordingSet, SupervisionSet, WavAugmenter
)
from espresso.tools.lhotse.manipulation import combine
from espresso.tools.lhotse.recipes.mobvoihotwords import download_and_untar, prepare_mobvoihotwords
from espresso.tools.lhotse.lhotse.manipulation import combine
from espresso.tools.lhotse.lhotse.recipes.mobvoihotwords import download_and_untar, prepare_mobvoihotwords
except ImportError:
raise ImportError("Please install Lhotse by `make lhotse` after entering espresso/tools")

Expand All @@ -33,7 +34,7 @@
level=os.environ.get("LOGLEVEL", "INFO").upper(),
stream=sys.stdout,
)
logger = logging.getLogger(__name__)
logger = logging.getLogger("mobvoihotwords.data_prep")


def get_parser():
Expand Down Expand Up @@ -64,11 +65,25 @@ def main(args):
corpus_dir = root_dir / "MobvoiHotwords"
output_dir = root_dir

# Download and extract the corpus
logger.info(f"Download and extract the corpus")
download_and_untar(root_dir)

# Prepare manifests
mobvoihotwords_manifests = prepare_mobvoihotwords(corpus_dir, output_dir)
logger.info(f"Prepare the manifests")
partitions = ["train", "dev", "test"]
if all(
(output_dir / f"{key}_{part}.json").is_file()
for key in ["recordings", "supervisions"] for part in partitions
):
logger.info(f"All the manifests files are found in {output_dir}. Load from them directly")
mobvoihotwords_manifests = defaultdict(dict)
for part in partitions:
mobvoihotwords_manifests[part] = {
"recordings": RecordingSet.from_json(output_dir / f"recordings_{part}.json"),
"supervisions": SupervisionSet.from_json(output_dir / f"supervisions_{part}.json")
}
else:
logger.info("It may take long time")
mobvoihotwords_manifests = prepare_mobvoihotwords(corpus_dir, output_dir)
logger.info(
"train/dev/test size: {}/{}/{}".format(
len(mobvoihotwords_manifests["train"]["recordings"]),
Expand All @@ -81,16 +96,17 @@ def main(args):
np.random.seed(args.seed)
# equivalent to Kaldi's mfcc_hires config
mfcc = Mfcc(config=MfccConfig(num_mel_bins=40, num_ceps=40, low_freq=20, high_freq=-400))
for partition, manifests in mobvoihotwords_manifests.items():
for part, manifests in mobvoihotwords_manifests.items():
cut_set = CutSet.from_manifests(
recordings=manifests["recordings"],
supervisions=manifests["supervisions"],
)
sampling_rate = next(iter(cut_set)).sampling_rate
with ProcessPoolExecutor(args.num_jobs) as ex:
if "train" in partition:
if part == "train":
# split negative recordings into smaller chunks with lengths sampled from
# length distribution of positive recordings
logger.info(f"Split negative recordings in '{part}' set")
pos_durs = get_positive_durations(manifests["supervisions"])
with numpy_seed(args.seed):
cut_set = keep_positives_and_split_negatives(
Expand All @@ -100,64 +116,71 @@ def main(args):
overlap_duration=args.overlap_duration,
)
# "clean" set
with LilcomFilesWriter(f"{output_dir}/feats_{partition}_clean") as storage:
logger.info(f"Extract features for '{part}' set")
with LilcomFilesWriter(f"{output_dir}/feats_{part}_clean") as storage:
cut_set_clean = cut_set.compute_and_store_features(
extractor=mfcc,
storage=storage,
augmenter=None,
executor=ex,
)
# augmented with reverberation
with LilcomFilesWriter(f"{output_dir}/feats_{partition}_rev") as storage:
logger.info(f"Extract features from '{part}' set with reverberation")
with LilcomFilesWriter(f"{output_dir}/feats_{part}_rev") as storage:
with numpy_seed(args.seed):
cut_set_rev = cut_set.compute_and_store_features(
extractor=mfcc,
storage=storage,
augmenter=WavAugmenter(effect_chain=reverb()),
excutor=ex,
augmenter=WavAugmenter(effect_chain=reverb(), sampling_rate=sampling_rate),
executor=ex,
)
cut_set_rev = CutSet.from_cuts(
cut.with_id("rev-" + cut.id) for cut in cut_set_rev.cuts
)
# augmented with speed perturbation
with LilcomFilesWriter(f"{output_dir}/feats_{partition}_sp1.1") as storage:
logger.info(f"Extract features from '{part}' set with speed perturbation")
with LilcomFilesWriter(f"{output_dir}/feats_{part}_sp1.1") as storage:
cut_set_sp1p1 = cut_set.compute_and_store_features(
extractor=mfcc,
storage=storage,
augmenter=WavAugmenter(
effect_chain=speed(sampling_rate=sampling_rate, factor=1.1)
effect_chain=speed(sampling_rate=sampling_rate, factor=1.1),
sampling_rate=sampling_rate,
),
excutor=ex,
executor=ex,
)
cut_set_sp1p1 = CutSet.from_cuts(
cut.with_id("sp1.1-" + cut.id) for cut in cut_set_sp1p1.cuts
)
with LilcomFilesWriter(f"{output_dir}/feats_{partition}_sp0.9") as storage:
with LilcomFilesWriter(f"{output_dir}/feats_{part}_sp0.9") as storage:
cut_set_sp0p9 = cut_set.compute_and_store_features(
extractor=mfcc,
storage=storage,
augmenter=WavAugmenter(
effect_chain=speed(sampling_rate=sampling_rate, factor=0.9)
effect_chain=speed(sampling_rate=sampling_rate, factor=0.9),
sampling_rate=sampling_rate,
),
excutor=ex,
executor=ex,
)
cut_set_sp0p9 = CutSet.from_cuts(
cut.with_id("sp0.9-" + cut.id) for cut in cut_set_sp0p9.cuts
)
# combine the clean and augmented sets together
logger.info(f"Combine all the features above")
cut_set = combine(
cut_set_clean, cut_set_rev, cut_set_sp1p1, cut_set_sp0p9
)
else: # no augmentations for dev and test sets
with LilcomFilesWriter(f"{output_dir}/feats_{partition}") as storage:
logger.info(f"extract features for '{part}' set")
with LilcomFilesWriter(f"{output_dir}/feats_{part}") as storage:
cut_set = cut_set.compute_and_store_features(
extractor=mfcc,
storage=storage,
augmenter=None,
executor=ex,
)
mobvoihotwords_manifests[partition]["cuts"] = cut_set
cut_set.to_json(output_dir / f"cuts_{partition}.json.gz")
mobvoihotwords_manifests[part]["cuts"] = cut_set
cut_set.to_json(output_dir / f"cuts_{part}.json.gz")


def get_positive_durations(sup_set: SupervisionSet) -> List[float]:
Expand All @@ -166,7 +189,7 @@ def get_positive_durations(sup_set: SupervisionSet) -> List[float]:
"FREETEXT" for all negative recordings, and SupervisionSegment.duration
equals to the corresponding Recording.duration.
"""
return [sup.dur for sup in sup_set.filter(lambda seg: seg.text != "FREETEXT")]
return [sup.duration for sup in sup_set.filter(lambda seg: seg.text != "FREETEXT")]


def keep_positives_and_split_negatives(
Expand Down
3 changes: 2 additions & 1 deletion examples/mobvoihotwords/path.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/tools/sctk/bin
export LC_ALL=C
# END

export PATH=~/anaconda3/bin:$PATH
export PATH=/export/b03/ywang/anaconda3/bin:$PATH
export PATH=$MAIN_ROOT:$MAIN_ROOT/espresso:$MAIN_ROOT/espresso/tools:$PATH
export LD_LIBRARY_PATH=$MAIN_ROOT/espresso/tools/openfst/lib:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=$MAIN_ROOT/espresso/tools/lhotse/tools/deps/sox-code/src/.libs:$LD_LIBRARY_PATH
export PYTHONPATH=$MAIN_ROOT:$MAIN_ROOT/espresso:$MAIN_ROOT/espresso/tools:$MAIN_ROOT/espresso/tools/lhotse:$MAIN_ROOT/espresso/tools/pychain:$PYTHONPATH
export PYTHONUNBUFFERED=1

0 comments on commit 3f8f008

Please sign in to comment.