Skip to content

Commit

Permalink
add random split of negatives
Browse files Browse the repository at this point in the history
  • Loading branch information
freewym committed Nov 7, 2020
1 parent 7dd7fee commit fa07059
Showing 1 changed file with 114 additions and 19 deletions.
133 changes: 114 additions & 19 deletions examples/mobvoihotwords/local/data_prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import logging
import os
import sys
from typing import List
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path

Expand All @@ -31,7 +32,15 @@ def get_parser():
parser.add_argument("--data-dir", default="data", type=str, help="data directory")
parser.add_argument("--seed", default=1, type=int, help="random seed")
parser.add_argument(
"--nj", default=1, type=int, help="number of jobs for features extraction"
"--num-jobs", default=1, type=int, help="number of jobs for features extraction"
)
parser.add_argument(
"--max-remaining-duration", default=0.3, type=float,
help="not split if the left-over duration is less than this many seconds"
)
parser.add_argument(
"--overlap-duration", default=0.3, type=float,
help="overlap between adjacent segments while splitting negative recordings"
)
# fmt: on

Expand All @@ -41,7 +50,9 @@ def get_parser():
def main(args):
try:
# TODO use pip install once it's available
from espresso.tools.lhotse import CutSet, Mfcc, MfccConfig, LilcomFilesWriter, WavAugmenter
from espresso.tools.lhotse import (
CutSet, Mfcc, MfccConfig, LilcomFilesWriter, SupervisionSet, WavAugmenter
)
from espresso.tools.lhotse.manipulation import combine
from espresso.tools.lhotse.recipes.mobvoihotwords import download_and_untar, prepare_mobvoihotwords
except ImportError:
Expand All @@ -68,36 +79,46 @@ def main(args):
np.random.seed(args.seed)
# equivalent to Kaldi's mfcc_hires config
mfcc = Mfcc(config=MfccConfig(num_mel_bins=40, num_ceps=40, low_freq=20, high_freq=-400))
num_jobs = args.nj
for partition, manifests in mobvoihotwords_manifests.items():
cut_set = CutSet.from_manifests(
recordings=manifests["recordings"],
supervisions=manifests["supervisions"],
)
sampling_rate = next(iter(cut_set)).sampling_rate
with ProcessPoolExecutor(num_jobs) as ex:
with ProcessPoolExecutor(args.num_jobs) as ex:
if "train" in partition:
# original set
with LilcomFilesWriter(f"{output_dir}/feats_{partition}_orig") as storage:
cut_set_orig = cut_set.compute_and_store_features(
# split negative recordings into smaller chunks with lengths sampled from
# length distribution of positive recordings
pos_durs = get_positive_durations(manifests["supervisions"])
with numpy_seed(args.seed):
cut_set = keep_positives_and_split_negatives(
cut_set,
pos_durs,
max_remaining_duration=args.max_remaining_duration,
overlap_duration=args.overlap_duration,
)
# "clean" set
with LilcomFilesWriter(f"{output_dir}/feats_{partition}_clean") as storage:
cut_set_clean = cut_set.compute_and_store_features(
extractor=mfcc,
storage=storage,
augmenter=None,
executor=ex,
)
# augmented with reverbration
with LilcomFilesWriter(f"{output_dir}/feats_{partition}_rev") as storage:
cut_set_rev = cut_set.compute_and_store_features(
extractor=mfcc,
storage=storage,
augmenter=WavAugmenter(effect_chain=reverb()),
excutor=ex,
)
# augmented with reverberation
with LilcomFilesWriter(f"{output_dir}/feats_{partition}_rev") as storage:
with numpy_seed(args.seed):
cut_set_rev = cut_set.compute_and_store_features(
extractor=mfcc,
storage=storage,
augmenter=WavAugmenter(effect_chain=reverb()),
excutor=ex,
)
cut_set_rev = CutSet.from_cuts(
cut.with_id("rev-" + cut.id) for cut in cut_set_rev.cuts
)
# augmented with speed perturbation
with LilcomFilesWriter(f"{output_dir}/feats_{partition}_sp1.1") as storage:
with LilcomFilesWriter(f"{output_dir}/feats_{partition}_sp1.1") as storage:
cut_set_sp1p1 = cut_set.compute_and_store_features(
extractor=mfcc,
storage=storage,
Expand All @@ -109,7 +130,7 @@ def main(args):
cut_set_sp1p1 = CutSet.from_cuts(
cut.with_id("sp1.1-" + cut.id) for cut in cut_set_sp1p1.cuts
)
with LilcomFilesWriter(f"{output_dir}/feats_{partition}_sp0.9") as storage:
with LilcomFilesWriter(f"{output_dir}/feats_{partition}_sp0.9") as storage:
cut_set_sp0p9 = cut_set.compute_and_store_features(
extractor=mfcc,
storage=storage,
Expand All @@ -121,9 +142,9 @@ def main(args):
cut_set_sp0p9 = CutSet.from_cuts(
cut.with_id("sp0.9-" + cut.id) for cut in cut_set_sp0p9.cuts
)
# combine the original and augmented sets together
# combine the clean and augmented sets together
cut_set = combine(
cut_set_orig, cut_set_rev, cut_set_sp1p1, cut_set_sp0p9
cut_set_clean, cut_set_rev, cut_set_sp1p1, cut_set_sp0p9
)
else: # no augmentations for dev and test sets
with LilcomFilesWriter(f"{output_dir}/feats_{partition}") as storage:
Expand All @@ -137,6 +158,80 @@ def main(args):
cut_set.to_json(output_dir / f"cuts_{partition}.json.gz")


def get_positive_durations(sup_set: SupervisionSet) -> List[float]:
"""
Get duration values of all positive recordings. Assume Supervison.text is
"FREETEXT" for all negative recordings, and SupervisionSegment.duration
equals to the corresponding Recording.duration.
"""
return [sup.dur for sup in sup_set.filter(lambda seg: seg.text != "FREETEXT")]


def keep_positives_and_split_negatives(
cut_set: CutSet,
durations: List[float],
max_remaining_duration: float = 0.3,
overlap_duration: float = 0.3,
) -> CutSet:
"""
Returns a new CutSet where all the positives are directly taken from the original
input cut set, and the negatives are obtained by splitting original negatives
into shorter chunks of random lengths drawn from the given length distribution
(here it is the empirical distribution of the positive recordings), There can
be overlap between chunks.
Args:
cut_set (CutSet): original input cut set
durations (list[float]): list of durations to sample from
max_remaining_duration (float, optional): not split if the left-over
duration is less than this many seconds (default: 0.3).
overlap_duration (float, optional): overlap between adjacent segments
(default: None)
Returns:
CutSet: a new cut set after split
"""
assert max_remaining_duration >= 0.0 and overlap_duration >= 0.0
new_cuts = []
for cut in cut_set:
assert len(cut.supervisions) == 1
if cut.supervisions[0].text != "FREETEXT": # keep the positive as it is
new_cuts.append(cut)
else:
this_offset = cut.start
this_offset_relative = this_offset - cut.start
remaining_duration = cut.duration
this_dur = durations[np.random.randint(len(durations))]
while remaining_duration > this_dur + max_remaining_duration:
new_cut = cut.truncate(
offset=this_offset_relative, duration=this_dur, preserve_id=True
)
new_cut = new_cut.with_id(
"{id}-{s:07d}-{e:07d}".format(
id=new_cut.id,
s=int(round(100 * this_offset_relative)),
e=int(round(100 * (this_offset_relative + this_dur)))
)
)
new_cuts.append(new_cut)
this_offset += this_dur - overlap_duration
this_offset_relative = this_offset - cut.start
remaining_duration -= this_dur - overlap_duration
this_dur = durations[np.random.randint(len(durations))]

new_cut = cut.truncate(offset=this_offset_relative, preserve_id=True)
new_cut = new_cut.with_id(
"{id}-{s:07d}-{e:07d}".format(
id=new_cut.id,
s=int(round(100 * this_offset_relative)),
e=int(round(100 * cut.duration))
)
)
new_cuts.append(new_cut)

return CutSet.from_cuts(new_cuts)


def reverb(*args, **kwargs):
"""
Returns a reverb effect for wav augmentation.
Expand Down

0 comments on commit fa07059

Please sign in to comment.