From 5d1001d12e9a9d0151040bd79f828a8a10246ead Mon Sep 17 00:00:00 2001 From: freewym Date: Sat, 1 Dec 2018 21:59:05 -0500 Subject: [PATCH 001/119] initial commit --- .gitignore | 3 + fairseq/data/__init__.py | 9 + fairseq/data/scp_dataset.py | 233 ++++++++++ fairseq/data/speech_dataset.py | 244 +++++++++++ fairseq/data/token_dictionary.py | 57 +++ fairseq/tasks/speech_recognition.py | 135 ++++++ speech_tools/kaldi_io.py | 630 ++++++++++++++++++++++++++++ speech_tools/utils.py | 75 ++++ tests/test_speech_dataset.py | 176 ++++++++ tests/test_speech_utils.py | 104 +++++ 10 files changed, 1666 insertions(+) create mode 100644 fairseq/data/scp_dataset.py create mode 100644 fairseq/data/speech_dataset.py create mode 100644 fairseq/data/token_dictionary.py create mode 100644 fairseq/tasks/speech_recognition.py create mode 100644 speech_tools/kaldi_io.py create mode 100644 speech_tools/utils.py create mode 100644 tests/test_speech_dataset.py create mode 100644 tests/test_speech_utils.py diff --git a/.gitignore b/.gitignore index 411280479..fbe71542a 100644 --- a/.gitignore +++ b/.gitignore @@ -134,3 +134,6 @@ experimental/* # Weights and Biases logs wandb/ + +# emacs saves +*~ diff --git a/fairseq/data/__init__.py b/fairseq/data/__init__.py index 9b3081395..ae7c0f1c9 100644 --- a/fairseq/data/__init__.py +++ b/fairseq/data/__init__.py @@ -56,6 +56,9 @@ from .multilingual.sampled_multi_dataset import SampledMultiDataset from .multilingual.sampled_multi_epoch_dataset import SampledMultiEpochDataset from .fasta_dataset import FastaDataset, EncodedFastaDataset +from .token_dictionary import TokenDictionary +from .scp_dataset import ScpDataset, ScpCachedDataset, ScpInMemoryDataset, TokenTextDataset +from .speech_dataset import SpeechDataset from .iterators import ( CountingIterator, @@ -121,4 +124,10 @@ "TransformEosLangPairDataset", "TruncateDataset", "TruncatedDictionary", + 'TokenDictionary', + 'ScpDataset', + 'ScpCachedDataset', + 'ScpInMemoryDataset', + 'TokenTextDataset', + 'SpeechDataset', ] diff --git a/fairseq/data/scp_dataset.py b/fairseq/data/scp_dataset.py new file mode 100644 index 000000000..761a59b96 --- /dev/null +++ b/fairseq/data/scp_dataset.py @@ -0,0 +1,233 @@ +# Copyright (c) 2018-present, Yiming Wang +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. 
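# A minimal usage sketch (not part of the original patch; file paths are
# hypothetical): how the classes newly exported from fairseq/data/__init__.py
# above are intended to fit together -- Kaldi features read from an scp file
# as the source, tokenized transcripts as the target, combined into a
# SpeechDataset for batching.
from fairseq.data import (
    ScpCachedDataset, SpeechDataset, TokenDictionary, TokenTextDataset,
)

dict = TokenDictionary.load('data/lang/dict.txt')           # hypothetical path
src = ScpCachedDataset('data/train/feats.scp', ordered_prefetch=True)
tgt = TokenTextDataset('data/train/text_tokens', dict)
dataset = SpeechDataset(src, src.sizes, tgt, tgt.sizes, dict)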
+ +import os + +import numpy as np +import torch + +import speech_tools.kaldi_io as kaldi_io +from speech_tools.utils import Tokenizer + +class ScpDataset(torch.utils.data.Dataset): + """Loader for TorchNet IndexedDataset""" + + def __init__(self, path): + super().__init__() + self.dtype = np.float + self.read_scp(path) + + def read_scp(self, path): + with open(path, 'r', encoding='utf-8') as f: + scp_entries = [line.strip().split(None, 1) for line in f] + self.utt_ids = [entry[0] for entry in scp_entries] + self.extended_filenames = [entry[1] for entry in scp_entries] + self.size = len(scp_entries) # number of utterances + self.sizes=[] # length of each utterance + for filename in self.extended_filenames: + try: + feat = kaldi_io.read_mat(filename) + except: + print('Failed to read feature matrix {}.'.format(filename)) + raise + assert feat is not None and isinstance(feat, np.ndarray) + self.sizes.append(feat.shape[0]) + self.sizes = np.array(self.sizes, dtype=np.int32) + self.feat_dim = feat.shape[1] # feature dimension + + assert len(self.utt_ids) == len(self.extended_filenames) and \ + len(self.utt_ids) == len(self.sizes) + + + def check_index(self, i): + if i < 0 or i >= self.size: + raise IndexError('index out of range') + + def filter_and_reorder(self, indices): + assert isinstance(indices, (list, np.ndarray)) + indices = np.array(indices) + assert all(indices < len(self.utt_ids)) and all(indices >= 0) + assert len(np.unique(indices)) == len(indices), \ + 'Duplicate elements in indices.' + self.utt_ids = [self.utt_ids[i] for i in indices] + self.extended_filenames = [self.extended_filenames[i] for i in indices] + self.sizes = self.sizes[indices] + self.size = len(self.utt_ids) + self.ordered_indices = list(range(self.size)) + + def __getitem__(self, i): + self.check_index(i) + feat = kaldi_io.read_mat(self.extended_filenames[i]) + item = torch.from_numpy(feat).float() + return item + + def __len__(self): + return self.size + + @staticmethod + def exists(path): + return os.path.exists(path) + + +class ScpCachedDataset(ScpDataset): + + def __init__(self, path, ordered_prefetch=False, cache_size=4096): + super().__init__(path) + self.cache = None + self.cache_index = {} + self.cache_size = cache_size # in terms of number of examples + self.start_search_for_next_pos_start = 0 + self.ordered_indices = list(range(self.size)) + self.ordered_prefetch = ordered_prefetch # set to True ONLY if examples + # are queried in the same order + # as self.ordered_indices, and + # doing this will speed up + # search of the queried index. + + @property + def supports_prefetch(self): + return True + + def prefetch(self, indices): + """Sets self.ordered_indices. If being called, the caller is supposed to + query examples in the same order as self.ordered_indices. + self.ordered_prefetch can be set to True in this case. Note: the purpose + of this function is different from what it is supposed to do in the + fairseq framework.""" + assert isinstance(indices, (list, np.ndarray)) + assert self.size >= len(indices) + self.ordered_indices = indices.copy() + + def __getitem__(self, i): + self.check_index(i) + if i not in self.cache_index: + assert self.start_search_for_next_pos_start < \ + len(self.ordered_indices), \ + 'Search position starting beyond the end of ordered_indices.' + try: + pos_start = self.ordered_indices.index(i, + self.start_search_for_next_pos_start) + except ValueError: + print('index {} not found in self.ordered_indices. 
Set ' + 'self.ordered_prefetch to False, and/or call self.prefetch() ' + 'with the full list of indices, and then try again.'.format(i)) + raise + pos_end = min(pos_start + self.cache_size, + len(self.ordered_indices)) + self.start_search_for_next_pos_start = pos_end \ + if self.ordered_prefetch else 0 + total_size = 0 + for idx in self.ordered_indices[pos_start : pos_end]: + total_size += self.sizes[idx] + self.cache = np.empty((total_size, self.feat_dim), dtype=self.dtype) + ptx = 0 + self.cache_index.clear() + for idx in self.ordered_indices[pos_start : pos_end]: + self.cache_index[idx] = ptx + length = self.sizes[idx] + dst = self.cache[ptx : ptx + length] + np.copyto(dst, kaldi_io.read_mat(self.extended_filenames[idx])) + ptx += length + + ptx = self.cache_index[i] + a = self.cache[ptx : ptx + self.sizes[i]].copy() + return torch.from_numpy(a).float() + + +class ScpInMemoryDataset(ScpDataset): + """Loader for TorchNet ScpDataset, keeps all the data in memory.""" + + def __init__(self, path): + super().__init__(path) + self.read_data() + + def read_data(self): + self.data_offsets = np.append([0], np.cumsum(self.sizes)[:-1]) + self.buffer = np.empty((sum(self.sizes), self.feat_dim), + dtype=self.dtype) + for i in range(len(self.data_offsets)): + ptx = self.data_offsets[i] + dst = self.buffer[ptx : ptx + self.sizes[i]] + np.copyto(dst, kaldi_io.read_mat(self.extended_filenames[i])) + + def filter_and_reorder(self, indices): + super().filter_and_reorder(indices) + self.read_data() + + def __getitem__(self, i): + self.check_index(i) + ptx = self.data_offsets[i] + a = self.buffer[ptx : ptx + self.sizes[i]].copy() + return torch.from_numpy(a).float() + + +class TokenTextDataset(torch.utils.data.Dataset): + """Takes a text file as input and binarizes it in memory at instantiation. + Original lines are also kept in memory. Each line of the text file is in the + format of 'utt_id tokenized_text'.""" + + def __init__(self, path, dictionary, append_eos=True): + super().__init__() + self.dtype = np.float + self.append_eos = append_eos + self.read_text(path, dictionary) + + def read_text(self, path, dictionary): + self.utt_ids = [] + self.tokens_list = [] + self.tensor_list = [] + self.sizes = [] + with open(path, 'r', encoding='utf-8') as f: + for line in f: + utt_id, tokens = line.strip().split(None, 1) + self.utt_ids.append(utt_id) + self.tokens_list.append(tokens) + tensor = Tokenizer.tokens_to_index_tensor(tokens, dictionary) + self.tensor_list.append(tensor) + self.sizes.append(len(self.tensor_list[-1])) + + self.size = len(self.utt_ids) # number of utterances + self.sizes = np.array(self.sizes, dtype=np.int32) + + assert len(self.utt_ids) == len(self.tokens_list) and \ + len(self.utt_ids) == len(self.tensor_list) and \ + len(self.utt_ids) == len(self.sizes) + + def check_index(self, i): + if i < 0 or i >= self.size: + raise IndexError('index out of range') + + def filter_and_reorder(self, indices): + assert isinstance(indices, (list, np.ndarray)) + indices = np.array(indices) + assert all(indices < self.size) and all(indices >= 0) + assert len(np.unique(indices)) == len(indices), \ + 'Duplicate elements in indices.' 
+ self.utt_ids = [self.utt_ids[i] for i in indices] + self.tokens_list = [self.tokens_list[i] for i in indices] + self.tensor_list = [self.tensor_list[i] for i in indices] + self.sizes = self.sizes[indices] + self.size = len(self.utt_ids) + + def __getitem__(self, i): + self.check_index(i) + return self.tensor_list[i] + + def get_original_tokens(self, i): + self.check_index(i) + return self.tokens_list[i] + + def get_original_text(self, i): + self.check_index(i) + return Tokenizer.tokens_to_sentence(self.tokens_list[i]) + + def __len__(self): + return self.size + + @staticmethod + def exists(path): + return os.path.exists(path) diff --git a/fairseq/data/speech_dataset.py b/fairseq/data/speech_dataset.py new file mode 100644 index 000000000..06aa78e78 --- /dev/null +++ b/fairseq/data/speech_dataset.py @@ -0,0 +1,244 @@ +# Copyright (c) 2018-present, Yiming Wang +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import numpy as np +import torch + +from fairseq import utils + +from . import data_utils, FairseqDataset +import speech_tools.utils as speech_utils + + +def collate( + samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False, + input_feeding=True, +): + if len(samples) == 0: + return {} + + def merge(key, left_pad, move_eos_to_beginning=False): + if key == 'source': + return speech_utils.collate_frames( + [s[key] for s in samples], 0.0, left_pad, + ) + elif key == 'target': + return data_utils.collate_tokens( + [s[key] for s in samples], + pad_idx, eos_idx, left_pad, move_eos_to_beginning, + ) + else: + raise ValueError('Invalid key.') + + id = torch.LongTensor([s['id'] for s in samples]) + utt_id = [s['utt_id'] for s in samples] + src_frames = merge('source', left_pad=left_pad_source) + # sort by descending source length + src_lengths = torch.IntTensor([s['source'].size(0) for s in samples]) + src_lengths, sort_order = src_lengths.sort(descending=True) + id = id.index_select(0, sort_order) + utt_id = [utt_id[i] for i in sort_order.numpy()] + src_frames = src_frames.index_select(0, sort_order) + + prev_output_tokens = None + target = None + if samples[0].get('target', None) is not None: + target = merge('target', left_pad=left_pad_target) + target = target.index_select(0, sort_order) + ntokens = sum(len(s['target']) for s in samples) + + if input_feeding: + # we create a shifted version of targets for feeding the + # previous output token(s) into the next decoder step + prev_output_tokens = merge( + 'target', + left_pad=left_pad_target, + move_eos_to_beginning=True, + ) + prev_output_tokens = prev_output_tokens.index_select(0, sort_order) + else: + ntokens = sum(len(s['source']) for s in samples) + + batch = { + 'id': id, + 'utt_id': utt_id, + 'nsentences': len(samples), + 'ntokens': ntokens, + 'net_input': { + 'src_tokens': src_frames, # key name kept due to + # FairseqModel::forward(...,src_tokens,...) + 'src_lengths': src_lengths, + }, + 'target': target, + } + if prev_output_tokens is not None: + batch['net_input']['prev_output_tokens'] = prev_output_tokens + return batch + + +class SpeechDataset(FairseqDataset): + """ + A pair of torch.utils.data.Datasets. 
+ + Args: + src (torch.utils.data.Dataset): source dataset to wrap + src_sizes (List[int]): source sentence lengths + tgt (torch.utils.data.Dataset, optional): target dataset to wrap + tgt_sizes (List[int], optional): target sentence lengths + dict (~fairseq.data.Dictionary, optional): target vocabulary + left_pad_source (bool, optional): pad source tensors on the left side. + Default: ``True`` + left_pad_target (bool, optional): pad target tensors on the left side. + Default: ``False`` + max_source_positions (int, optional): max number of frames in the + source. Default: ``1024`` + max_target_positions (int, optional): max number of tokens in the target + sentence. Default: ``1024`` + shuffle (bool, optional): shuffle dataset elements before batching. + Default: ``True`` + input_feeding (bool, optional): create a shifted version of the targets + to be passed into the model for input feeding/teacher forcing. + Default: ``True`` + """ + + def __init__( + self, src, src_sizes, + tgt=None, tgt_sizes=None, dict=None, + left_pad_source=True, left_pad_target=False, + max_source_positions=1024, max_target_positions=1024, + shuffle=True, input_feeding=True, + ): + self.src = src + self.tgt = tgt + self.src_sizes = np.array(src_sizes) + self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None + self.dict = dict + self.left_pad_source = left_pad_source + self.left_pad_target = left_pad_target + self.max_source_positions = max_source_positions + self.max_target_positions = max_target_positions + self.shuffle = shuffle + self.input_feeding = input_feeding + if self.tgt is not None: + self._match_src_tgt() + + def _match_src_tgt(self): + """Makes utterances in src and tgt the same order in terms of + their utt_ids. Removes those that only appear in one of them.""" + assert self.tgt is not None + if self.src.utt_ids == self.tgt.utt_ids: + return + tgt_utt_ids_set = set(self.tgt.utt_ids) + src_indices = [i for i, id in enumerate(self.src.utt_ids) \ + if id in tgt_utt_ids_set] + self.src.filter_and_reorder(src_indices) + try: + tgt_indices = list(map(self.tgt.utt_ids.index, self.src.utt_ids)) + except ValueError: + print('Unable to find some utt_id(s) in tgt. which is unlikely to \ + happen. Something must be wrong.') + raise + self.tgt.filter_and_reorder(tgt_indices) + assert self.src.utt_ids == self.tgt.utt_ids + + def __getitem__(self, index): + tgt_item = self.tgt[index] if self.tgt is not None else None + src_item = self.src[index] + return { + 'id': index, + 'utt_id': self.src.utt_ids[index], + 'source': src_item, + 'target': tgt_item, + } + + def __len__(self): + return len(self.src) + + def collater(self, samples): + """Merge a list of samples to form a mini-batch. + + Args: + samples (List[dict]): samples to collate + + Returns: + dict: a mini-batch with the following keys: + + - `id` (LongTensor): example IDs in the original input order + - `ntokens` (int): total number of tokens in the batch + - `net_input` (dict): the input to the Model, containing keys: + + - `src_tokens` (FloatTensor): a padded 3D Tensor of features in + the source of shape `(bsz, src_len, feat_dim)`. Padding will + appear on the left if *left_pad_source* is ``True``. + - `src_lengths` (IntTensor): 1D Tensor of the unpadded + lengths of each source sequence of shape `(bsz)` + - `prev_output_tokens` (LongTensor): a padded 2D Tensor of + tokens in the target sentence, shifted right by one position + for input feeding/teacher forcing, of shape `(bsz, + tgt_len)`. 
This key will not be present if *input_feeding* + is ``False``. Padding will appear on the left if + *left_pad_target* is ``True``. + + - `target` (IntTensor): a padded 2D Tensor of tokens in the + target sentence of shape `(bsz, tgt_len)`. Padding will appear + on the left if *left_pad_target* is ``True``. + """ + return collate( + samples, pad_idx=self.dict.pad(), eos_idx=self.dict.eos(), + left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target, + input_feeding=self.input_feeding, + ) + + def get_dummy_batch(self, num_tokens, max_positions, src_len=128, tgt_len=128): + """Return a dummy batch with a given number of tokens.""" + src_len, tgt_len = utils.resolve_max_positions( + (src_len, tgt_len), + max_positions, + (self.max_source_positions, self.max_target_positions), + ) + bsz = max(num_tokens // tgt_len, 1) + return self.collater([ + { + 'id': i, + 'utt_id': 'dummy' + str(i), + 'source': torch.FloatTensor(src_len, self.src.feat_dim).uniform_(-10.0, 10.0), + 'target': self.dict.dummy_sentence(tgt_len) if self.dict is not None else None, + } + for i in range(bsz) + ]) + + def num_tokens(self, index): + """Return the number of tokens in a sample. This value is used to + enforce ``--max-tokens`` during batching.""" + return self.tgt_sizes[index] if self.tgt_sizes is not None else 0 + + def size(self, index): + """Return an example's size as a float or tuple. This value is used when + filtering a dataset with ``--max-positions``.""" + return (self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0) + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + if self.shuffle: + indices = np.random.permutation(len(self)) + else: + indices = np.arange(len(self)) + if self.tgt_sizes is not None: + indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')] + return indices[np.argsort(self.src_sizes[indices], kind='mergesort')] + + def prefetch(self, indices): + """Only prefetch src.""" + self.src.prefetch(indices) + + @property + def supports_prefetch(self): + return ( + hasattr(self.src, 'supports_prefetch') + and self.src.supports_prefetch + ) diff --git a/fairseq/data/token_dictionary.py b/fairseq/data/token_dictionary.py new file mode 100644 index 000000000..026ceb0fb --- /dev/null +++ b/fairseq/data/token_dictionary.py @@ -0,0 +1,57 @@ +# Copyright (c) 2018-present, Yiming Wang +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +from fairseq.data import Dictionary + +import torch + + +class TokenDictionary(Dictionary): + """A mapping from symbols to consecutive integers""" + def __init__(self, pad='', eos='', unk='', space=''): + self.unk_word, self.pad_word, self.eos_word, self.space_word = \ + unk, pad, eos, space + self.symbols = [] + self.count = [] + self.indices = {} + self.pad_index = self.add_symbol(pad) + self.eos_index = self.add_symbol(eos) + self.unk_index = self.add_symbol(unk) + self.space_index = self.add_symbol(space) + self.nspecial = len(self.symbols) + + def string(self, tensor, bpe_symbol=None, escape_unk=False): + """Helper for converting a tensor of token indices to a string. + + Can optionally remove BPE symbols or escape words. + + We overwrite this since we would like to also ignore . 
+ """ + if torch.is_tensor(tensor) and tensor.dim() == 2: + return '\n'.join(self.string(t) for t in tensor) + + def token_string(i): + if i == self.unk(): + return self.unk_string(escape_unk) + else: + return self[i] + + sent = ' '.join(token_string(i) for i in tensor if i != self.eos() and \ + i != self.pad()) + if bpe_symbol is not None: + sent = (sent + ' ').replace(bpe_symbol, '').rstrip() + return sent + + def space(self): + """Helper to get index of space symbol""" + return self.space_index + + def dummy_sentence(self, length): + # sample starting from space + t = torch.Tensor(length).uniform_(self.nspecial - 1, len(self)).int() + t[-1] = self.eos() + return t diff --git a/fairseq/tasks/speech_recognition.py b/fairseq/tasks/speech_recognition.py new file mode 100644 index 000000000..96858c14d --- /dev/null +++ b/fairseq/tasks/speech_recognition.py @@ -0,0 +1,135 @@ +# Copyright (c) 2018-present, Yiming Wang +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import itertools +import numpy as np +import os + +from fairseq import options +from fairseq.data import ( + data_utils, TokenDictionary, SpeechDataset, ConcatDataset, + TokenTextDataset, ScpCachedDataset +) + +from . import FairseqTask, register_task + + +@register_task('speech_recognition') +class SpeechRecognitionTask(FairseqTask): + """ + Translate from speech (source) to token text (target). + + Args: + dict (Dictionary): dictionary for the output tokens + + .. note:: + + The speech recognition task is compatible with :mod:`train.py `, + :mod:`generate.py ` and :mod:`interactive.py `. + + The speech_recognition task provides the following additional command-line + arguments: + + .. argparse:: + :ref: fairseq.tasks.speech_recognition_parser + :prog: + """ + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + parser.add_argument('scp_files', nargs='+', help='path(s) to scp file(s)') + parser.add_argument('text_files', nargs='+', help='path(s) to text file(s)') + parser.add_argument('--dict', default=None, type=str, + help='path to the dictionary') + parser.add_argument('--raw-text', action='store_true', + help='load raw text dataset') + parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL', + help='pad the source on the left') + parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL', + help='pad the target on the left') + parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N', + help='max number of tokens in the source sequence') + parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N', + help='max number of tokens in the target sequence') + parser.add_argument('--upsample-primary', default=1, type=int, + help='amount to upsample primary dataset') + + def __init__(self, args, dict): + super().__init__(args) + self.dict = dict + + @classmethod + def setup_task(cls, args, **kwargs): + """Setup the task (e.g., load dictionaries). 
+ + Args: + args (argparse.Namespace): parsed command-line arguments + """ + args.left_pad_source = options.eval_bool(args.left_pad_source) + args.left_pad_target = options.eval_bool(args.left_pad_target) + + # load dictionaries + dict_path = os.path.join(os.path.dirname(args.text_files[0]), + 'dict.txt') if args.dict is None else args.dict + dict = TokenDictionary.load(dict_path) + print('| dictionary: {} types'.format(len(dict))) + + return cls(args, dict) + + def load_dataset(self, split, combine=False, **kwargs): + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + src_datasets = [] + tgt_datasets = [] + + assert len(self.args.scp_files) == len(self.args.text_files) + file_pairs = zip(self.args.scp_files, self.args.text_file) + for scp, text in enumerate(file_pairs): + assert ScpCachedDataset.exists(scp) and TokenTextDataset.exists(text) + src_datasets.append(ScpCachedDataset(scp, ordered_indices=True)) + tgt_datasets.append(TokenTextDataset(text, self.dict)) + print('| {} {} examples'.format(scp, len(src_datasets[-1]))) + print('| {} {} examples'.format(text, len(tgt_datasets[-1]))) + + if not combine: + break + + assert len(src_datasets) == len(tgt_datasets) + + if len(src_datasets) == 1: + src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0] + else: + sample_ratios = [1] * len(src_datasets) + sample_ratios[0] = self.args.upsample_primary + src_dataset = ConcatDataset(src_datasets, sample_ratios) + tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) + + self.datasets[split] = SpeechDataset( + src_dataset, src_dataset.sizes, + tgt_dataset, tgt_dataset.sizes, self.dict, + left_pad_source=self.args.left_pad_source, + left_pad_target=self.args.left_pad_target, + max_source_positions=self.args.max_source_positions, + max_target_positions=self.args.max_target_positions, + ) + + def max_positions(self): + """Return the max sentence length allowed by the task.""" + return (self.args.max_source_positions, self.args.max_target_positions) + + def source_dictionary(self): + """Return the source :class:`~fairseq.data.Dictionary`.""" + return None + + @property + def target_dictionary(self): + """Return the target :class:`~fairseq.data.Dictionary`.""" + return self.dict diff --git a/speech_tools/kaldi_io.py b/speech_tools/kaldi_io.py new file mode 100644 index 000000000..7cc7a1dec --- /dev/null +++ b/speech_tools/kaldi_io.py @@ -0,0 +1,630 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# Copyright 2014-2016 Brno University of Technology (author: Karel Vesely) +# Licensed under the Apache License, Version 2.0 (the "License") + +import numpy as np +import sys, os, re, gzip, struct + +################################################# +# Adding kaldi tools to shell path, + +# Select kaldi, +if not 'KALDI_ROOT' in os.environ: + # Default! 
To change run python with 'export KALDI_ROOT=/some_dir python' + os.environ['KALDI_ROOT']='/mnt/matylda5/iveselyk/Tools/kaldi-trunk' + +# Add kaldi tools to path, +os.environ['PATH'] = os.popen('echo $KALDI_ROOT/src/bin:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/src/fstbin/:$KALDI_ROOT/src/gmmbin/:$KALDI_ROOT/src/featbin/:$KALDI_ROOT/src/lm/:$KALDI_ROOT/src/sgmmbin/:$KALDI_ROOT/src/sgmm2bin/:$KALDI_ROOT/src/fgmmbin/:$KALDI_ROOT/src/latbin/:$KALDI_ROOT/src/nnetbin:$KALDI_ROOT/src/nnet2bin:$KALDI_ROOT/src/nnet3bin:$KALDI_ROOT/src/online2bin/:$KALDI_ROOT/src/ivectorbin/:$KALDI_ROOT/src/lmbin/').readline().strip() + ':' + os.environ['PATH'] + + +################################################# +# Define all custom exceptions, +class UnsupportedDataType(Exception): pass +class UnknownVectorHeader(Exception): pass +class UnknownMatrixHeader(Exception): pass + +class BadSampleSize(Exception): pass +class BadInputFormat(Exception): pass + +class SubprocessFailed(Exception): pass + +################################################# +# Data-type independent helper functions, + +def open_or_fd(file, mode='rb'): + """ fd = open_or_fd(file) + Open file, gzipped file, pipe, or forward the file-descriptor. + Eventually seeks in the 'file' argument contains ':offset' suffix. + """ + offset = None + try: + # strip 'ark:' prefix from r{x,w}filename (optional), + if re.search('^(ark|scp)(,scp|,b|,t|,n?f|,n?p|,b?o|,n?s|,n?cs)*:', file): + (prefix,file) = file.split(':',1) + # separate offset from filename (optional), + if re.search(':[0-9]+$', file): + (file,offset) = file.rsplit(':',1) + # input pipe? + if file[-1] == '|': + fd = popen(file[:-1], 'rb') # custom, + # output pipe? + elif file[0] == '|': + fd = popen(file[1:], 'wb') # custom, + # is it gzipped? + elif file.split('.')[-1] == 'gz': + fd = gzip.open(file, mode) + # a normal file... + else: + fd = open(file, mode) + except TypeError: + # 'file' is opened file descriptor, + fd = file + # Eventually seek to offset, + if offset != None: fd.seek(int(offset)) + return fd + +# based on '/usr/local/lib/python3.4/os.py' +def popen(cmd, mode="rb"): + if not isinstance(cmd, str): + raise TypeError("invalid cmd type (%s, expected string)" % type(cmd)) + + import subprocess, io, threading + + # cleanup function for subprocesses, + def cleanup(proc, cmd): + ret = proc.wait() + if ret > 0: + raise SubprocessFailed('cmd %s returned %d !' % (cmd,ret)) + return + + # text-mode, + if mode == "r": + proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE) + threading.Thread(target=cleanup,args=(proc,cmd)).start() # clean-up thread, + return io.TextIOWrapper(proc.stdout) + elif mode == "w": + proc = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE) + threading.Thread(target=cleanup,args=(proc,cmd)).start() # clean-up thread, + return io.TextIOWrapper(proc.stdin) + # binary, + elif mode == "rb": + proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE) + threading.Thread(target=cleanup,args=(proc,cmd)).start() # clean-up thread, + return proc.stdout + elif mode == "wb": + proc = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE) + threading.Thread(target=cleanup,args=(proc,cmd)).start() # clean-up thread, + return proc.stdin + # sanity, + else: + raise ValueError("invalid mode %s" % mode) + + +def read_key(fd): + """ [key] = read_key(fd) + Read the utterance-key from the opened ark/stream descriptor 'fd'. 
+ """ + key = '' + while 1: + char = fd.read(1).decode("latin1") + if char == '' : break + if char == ' ' : break + key += char + key = key.strip() + if key == '': return None # end of file, + assert(re.match('^\S+$',key) != None) # check format (no whitespace!) + return key + + +################################################# +# Integer vectors (alignments, ...), + +def read_ali_ark(file_or_fd): + """ Alias to 'read_vec_int_ark()' """ + return read_vec_int_ark(file_or_fd) + +def read_vec_int_ark(file_or_fd): + """ generator(key,vec) = read_vec_int_ark(file_or_fd) + Create generator of (key,vector) tuples, which reads from the ark file/stream. + file_or_fd : ark, gzipped ark, pipe or opened file descriptor. + + Read ark to a 'dictionary': + d = { u:d for u,d in kaldi_io.read_vec_int_ark(file) } + """ + fd = open_or_fd(file_or_fd) + try: + key = read_key(fd) + while key: + ali = read_vec_int(fd) + yield key, ali + key = read_key(fd) + finally: + if fd is not file_or_fd: fd.close() + +def read_vec_int(file_or_fd): + """ [int-vec] = read_vec_int(file_or_fd) + Read kaldi integer vector, ascii or binary input, + """ + fd = open_or_fd(file_or_fd) + binary = fd.read(2).decode() + if binary == '\0B': # binary flag + assert(fd.read(1).decode() == '\4'); # int-size + vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # vector dim + # Elements from int32 vector are sored in tuples: (sizeof(int32), value), + vec = np.frombuffer(fd.read(vec_size*5), dtype=[('size','int8'),('value','int32')], count=vec_size) + assert(vec[0]['size'] == 4) # int32 size, + ans = vec[:]['value'] # values are in 2nd column, + else: # ascii, + arr = (binary + fd.readline().decode()).strip().split() + try: + arr.remove('['); arr.remove(']') # optionally + except ValueError: + pass + ans = np.array(arr, dtype=int) + if fd is not file_or_fd : fd.close() # cleanup + return ans + +# Writing, +def write_vec_int(file_or_fd, v, key=''): + """ write_vec_int(f, v, key='') + Write a binary kaldi integer vector to filename or stream. + Arguments: + file_or_fd : filename or opened file descriptor for writing, + v : the vector to be stored, + key (optional) : used for writing ark-file, the utterance-id gets written before the vector. + + Example of writing single vector: + kaldi_io.write_vec_int(filename, vec) + + Example of writing arkfile: + with open(ark_file,'w') as f: + for key,vec in dict.iteritems(): + kaldi_io.write_vec_flt(f, vec, key=key) + """ + fd = open_or_fd(file_or_fd, mode='wb') + if sys.version_info[0] == 3: assert(fd.mode == 'wb') + try: + if key != '' : fd.write((key+' ').encode("latin1")) # ark-files have keys (utterance-id), + fd.write('\0B'.encode()) # we write binary! + # dim, + fd.write('\4'.encode()) # int32 type, + fd.write(struct.pack(np.dtype('int32').char, v.shape[0])) + # data, + for i in range(len(v)): + fd.write('\4'.encode()) # int32 type, + fd.write(struct.pack(np.dtype('int32').char, v[i])) # binary, + finally: + if fd is not file_or_fd : fd.close() + + +################################################# +# Float vectors (confidences, ivectors, ...), + +# Reading, +def read_vec_flt_scp(file_or_fd): + """ generator(key,mat) = read_vec_flt_scp(file_or_fd) + Returns generator of (key,vector) tuples, read according to kaldi scp. + file_or_fd : scp, gzipped scp, pipe or opened file descriptor. + + Iterate the scp: + for key,vec in kaldi_io.read_vec_flt_scp(file): + ... 
+ + Read scp to a 'dictionary': + d = { key:mat for key,mat in kaldi_io.read_mat_scp(file) } + """ + fd = open_or_fd(file_or_fd) + try: + for line in fd: + (key,rxfile) = line.decode().split(' ') + vec = read_vec_flt(rxfile) + yield key, vec + finally: + if fd is not file_or_fd : fd.close() + +def read_vec_flt_ark(file_or_fd): + """ generator(key,vec) = read_vec_flt_ark(file_or_fd) + Create generator of (key,vector) tuples, reading from an ark file/stream. + file_or_fd : ark, gzipped ark, pipe or opened file descriptor. + + Read ark to a 'dictionary': + d = { u:d for u,d in kaldi_io.read_vec_flt_ark(file) } + """ + fd = open_or_fd(file_or_fd) + try: + key = read_key(fd) + while key: + ali = read_vec_flt(fd) + yield key, ali + key = read_key(fd) + finally: + if fd is not file_or_fd: fd.close() + +def read_vec_flt(file_or_fd): + """ [flt-vec] = read_vec_flt(file_or_fd) + Read kaldi float vector, ascii or binary input, + """ + fd = open_or_fd(file_or_fd) + binary = fd.read(2).decode() + if binary == '\0B': # binary flag + return _read_vec_flt_binary(fd) + else: # ascii, + arr = (binary + fd.readline().decode()).strip().split() + try: + arr.remove('['); arr.remove(']') # optionally + except ValueError: + pass + ans = np.array(arr, dtype=float) + if fd is not file_or_fd : fd.close() # cleanup + return ans + +def _read_vec_flt_binary(fd): + header = fd.read(3).decode() + if header == 'FV ' : sample_size = 4 # floats + elif header == 'DV ' : sample_size = 8 # doubles + else : raise UnknownVectorHeader("The header contained '%s'" % header) + assert (sample_size > 0) + # Dimension, + assert (fd.read(1).decode() == '\4'); # int-size + vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # vector dim + # Read whole vector, + buf = fd.read(vec_size * sample_size) + if sample_size == 4 : ans = np.frombuffer(buf, dtype='float32') + elif sample_size == 8 : ans = np.frombuffer(buf, dtype='float64') + else : raise BadSampleSize + return ans + + +# Writing, +def write_vec_flt(file_or_fd, v, key=''): + """ write_vec_flt(f, v, key='') + Write a binary kaldi vector to filename or stream. Supports 32bit and 64bit floats. + Arguments: + file_or_fd : filename or opened file descriptor for writing, + v : the vector to be stored, + key (optional) : used for writing ark-file, the utterance-id gets written before the vector. + + Example of writing single vector: + kaldi_io.write_vec_flt(filename, vec) + + Example of writing arkfile: + with open(ark_file,'w') as f: + for key,vec in dict.iteritems(): + kaldi_io.write_vec_flt(f, vec, key=key) + """ + fd = open_or_fd(file_or_fd, mode='wb') + if sys.version_info[0] == 3: assert(fd.mode == 'wb') + try: + if key != '' : fd.write((key+' ').encode("latin1")) # ark-files have keys (utterance-id), + fd.write('\0B'.encode()) # we write binary! + # Data-type, + if v.dtype == 'float32': fd.write('FV '.encode()) + elif v.dtype == 'float64': fd.write('DV '.encode()) + else: raise UnsupportedDataType("'%s', please use 'float32' or 'float64'" % v.dtype) + # Dim, + fd.write('\04'.encode()) + fd.write(struct.pack(np.dtype('uint32').char, v.shape[0])) # dim + # Data, + fd.write(v.tobytes()) + finally: + if fd is not file_or_fd : fd.close() + + +################################################# +# Float matrices (features, transformations, ...), + +# Reading, +def read_mat_scp(file_or_fd): + """ generator(key,mat) = read_mat_scp(file_or_fd) + Returns generator of (key,matrix) tuples, read according to kaldi scp. + file_or_fd : scp, gzipped scp, pipe or opened file descriptor. 
+ + Iterate the scp: + for key,mat in kaldi_io.read_mat_scp(file): + ... + + Read scp to a 'dictionary': + d = { key:mat for key,mat in kaldi_io.read_mat_scp(file) } + """ + fd = open_or_fd(file_or_fd) + try: + for line in fd: + (key,rxfile) = line.decode().split(' ') + mat = read_mat(rxfile) + yield key, mat + finally: + if fd is not file_or_fd : fd.close() + +def read_mat_ark(file_or_fd): + """ generator(key,mat) = read_mat_ark(file_or_fd) + Returns generator of (key,matrix) tuples, read from ark file/stream. + file_or_fd : scp, gzipped scp, pipe or opened file descriptor. + + Iterate the ark: + for key,mat in kaldi_io.read_mat_ark(file): + ... + + Read ark to a 'dictionary': + d = { key:mat for key,mat in kaldi_io.read_mat_ark(file) } + """ + fd = open_or_fd(file_or_fd) + try: + key = read_key(fd) + while key: + mat = read_mat(fd) + yield key, mat + key = read_key(fd) + finally: + if fd is not file_or_fd : fd.close() + +def read_mat(file_or_fd): + """ [mat] = read_mat(file_or_fd) + Reads single kaldi matrix, supports ascii and binary. + file_or_fd : file, gzipped file, pipe or opened file descriptor. + """ + fd = open_or_fd(file_or_fd) + try: + binary = fd.read(2).decode() + if binary == '\0B' : + mat = _read_mat_binary(fd) + else: + assert(binary == ' [') + mat = _read_mat_ascii(fd) + finally: + if fd is not file_or_fd: fd.close() + return mat + +def _read_mat_binary(fd): + # Data type + header = fd.read(3).decode() + # 'CM', 'CM2', 'CM3' are possible values, + if header.startswith('CM'): return _read_compressed_mat(fd, header) + elif header == 'FM ': sample_size = 4 # floats + elif header == 'DM ': sample_size = 8 # doubles + else: raise UnknownMatrixHeader("The header contained '%s'" % header) + assert(sample_size > 0) + # Dimensions + s1, rows, s2, cols = np.frombuffer(fd.read(10), dtype='int8,int32,int8,int32', count=1)[0] + # Read whole matrix + buf = fd.read(rows * cols * sample_size) + if sample_size == 4 : vec = np.frombuffer(buf, dtype='float32') + elif sample_size == 8 : vec = np.frombuffer(buf, dtype='float64') + else : raise BadSampleSize + mat = np.reshape(vec,(rows,cols)) + return mat + +def _read_mat_ascii(fd): + rows = [] + while 1: + line = fd.readline().decode() + if (len(line) == 0) : raise BadInputFormat # eof, should not happen! + if len(line.strip()) == 0 : continue # skip empty line + arr = line.strip().split() + if arr[-1] != ']': + rows.append(np.array(arr,dtype='float32')) # not last line + else: + rows.append(np.array(arr[:-1],dtype='float32')) # last line + mat = np.vstack(rows) + return mat + + +def _read_compressed_mat(fd, format): + """ Read a compressed matrix, + see: https://github.com/kaldi-asr/kaldi/blob/master/src/matrix/compressed-matrix.h + methods: CompressedMatrix::Read(...), CompressedMatrix::CopyToMat(...), + """ + assert(format == 'CM ') # The formats CM2, CM3 are not supported... + + # Format of header 'struct', + global_header = np.dtype([('minvalue','float32'),('range','float32'),('num_rows','int32'),('num_cols','int32')]) # member '.format' is not written, + per_col_header = np.dtype([('percentile_0','uint16'),('percentile_25','uint16'),('percentile_75','uint16'),('percentile_100','uint16')]) + + # Read global header, + globmin, globrange, rows, cols = np.frombuffer(fd.read(16), dtype=global_header, count=1)[0] + + # The data is structed as [Colheader, ... , Colheader, Data, Data , .... 
] + # { cols }{ size } + col_headers = np.frombuffer(fd.read(cols*8), dtype=per_col_header, count=cols) + col_headers = np.array([np.array([x for x in y]) * globrange * 1.52590218966964e-05 + globmin for y in col_headers], dtype=np.float32) + data = np.reshape(np.frombuffer(fd.read(cols*rows), dtype='uint8', count=cols*rows), newshape=(cols,rows)) # stored as col-major, + + mat = np.zeros((cols,rows), dtype='float32') + p0 = col_headers[:, 0].reshape(-1, 1) + p25 = col_headers[:, 1].reshape(-1, 1) + p75 = col_headers[:, 2].reshape(-1, 1) + p100 = col_headers[:, 3].reshape(-1, 1) + mask_0_64 = (data <= 64) + mask_193_255 = (data > 192) + mask_65_192 = (~(mask_0_64 | mask_193_255)) + + mat += (p0 + (p25 - p0) / 64. * data) * mask_0_64.astype(np.float32) + mat += (p25 + (p75 - p25) / 128. * (data - 64)) * mask_65_192.astype(np.float32) + mat += (p75 + (p100 - p75) / 63. * (data - 192)) * mask_193_255.astype(np.float32) + + return mat.T # transpose! col-major -> row-major, + + +# Writing, +def write_mat(file_or_fd, m, key=''): + """ write_mat(f, m, key='') + Write a binary kaldi matrix to filename or stream. Supports 32bit and 64bit floats. + Arguments: + file_or_fd : filename of opened file descriptor for writing, + m : the matrix to be stored, + key (optional) : used for writing ark-file, the utterance-id gets written before the matrix. + + Example of writing single matrix: + kaldi_io.write_mat(filename, mat) + + Example of writing arkfile: + with open(ark_file,'w') as f: + for key,mat in dict.iteritems(): + kaldi_io.write_mat(f, mat, key=key) + """ + fd = open_or_fd(file_or_fd, mode='wb') + if sys.version_info[0] == 3: assert(fd.mode == 'wb') + try: + if key != '' : fd.write((key+' ').encode("latin1")) # ark-files have keys (utterance-id), + fd.write('\0B'.encode()) # we write binary! + # Data-type, + if m.dtype == 'float32': fd.write('FM '.encode()) + elif m.dtype == 'float64': fd.write('DM '.encode()) + else: raise UnsupportedDataType("'%s', please use 'float32' or 'float64'" % m.dtype) + # Dims, + fd.write('\04'.encode()) + fd.write(struct.pack(np.dtype('uint32').char, m.shape[0])) # rows + fd.write('\04'.encode()) + fd.write(struct.pack(np.dtype('uint32').char, m.shape[1])) # cols + # Data, + fd.write(m.tobytes()) + finally: + if fd is not file_or_fd : fd.close() + + +################################################# +# 'Posterior' kaldi type (posteriors, confusion network, nnet1 training targets, ...) +# Corresponds to: vector > > +# - outer vector: time axis +# - inner vector: records at the time +# - tuple: int = index, float = value +# + +def read_cnet_ark(file_or_fd): + """ Alias of function 'read_post_ark()', 'cnet' = confusion network """ + return read_post_ark(file_or_fd) + +def read_post_ark(file_or_fd): + """ generator(key,vec>) = read_post_ark(file) + Returns generator of (key,posterior) tuples, read from ark file. + file_or_fd : ark, gzipped ark, pipe or opened file descriptor. + + Iterate the ark: + for key,post in kaldi_io.read_post_ark(file): + ... + + Read ark to a 'dictionary': + d = { key:post for key,post in kaldi_io.read_post_ark(file) } + """ + fd = open_or_fd(file_or_fd) + try: + key = read_key(fd) + while key: + post = read_post(fd) + yield key, post + key = read_key(fd) + finally: + if fd is not file_or_fd: fd.close() + +def read_post(file_or_fd): + """ [post] = read_post(file_or_fd) + Reads single kaldi 'Posterior' in binary format. 
+ + The 'Posterior' is C++ type 'vector > >', + the outer-vector is usually time axis, inner-vector are the records + at given time, and the tuple is composed of an 'index' (integer) + and a 'float-value'. The 'float-value' can represent a probability + or any other numeric value. + + Returns vector of vectors of tuples. + """ + fd = open_or_fd(file_or_fd) + ans=[] + binary = fd.read(2).decode(); assert(binary == '\0B'); # binary flag + assert(fd.read(1).decode() == '\4'); # int-size + outer_vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # number of frames (or bins) + + # Loop over 'outer-vector', + for i in range(outer_vec_size): + assert(fd.read(1).decode() == '\4'); # int-size + inner_vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # number of records for frame (or bin) + data = np.frombuffer(fd.read(inner_vec_size*10), dtype=[('size_idx','int8'),('idx','int32'),('size_post','int8'),('post','float32')], count=inner_vec_size) + assert(data[0]['size_idx'] == 4) + assert(data[0]['size_post'] == 4) + ans.append(data[['idx','post']].tolist()) + + if fd is not file_or_fd: fd.close() + return ans + + +################################################# +# Kaldi Confusion Network bin begin/end times, +# (kaldi stores CNs time info separately from the Posterior). +# + +def read_cntime_ark(file_or_fd): + """ generator(key,vec>) = read_cntime_ark(file_or_fd) + Returns generator of (key,cntime) tuples, read from ark file. + file_or_fd : file, gzipped file, pipe or opened file descriptor. + + Iterate the ark: + for key,time in kaldi_io.read_cntime_ark(file): + ... + + Read ark to a 'dictionary': + d = { key:time for key,time in kaldi_io.read_post_ark(file) } + """ + fd = open_or_fd(file_or_fd) + try: + key = read_key(fd) + while key: + cntime = read_cntime(fd) + yield key, cntime + key = read_key(fd) + finally: + if fd is not file_or_fd : fd.close() + +def read_cntime(file_or_fd): + """ [cntime] = read_cntime(file_or_fd) + Reads single kaldi 'Confusion Network time info', in binary format: + C++ type: vector >. + (begin/end times of bins at the confusion network). + + Binary layout is ' ...' + + file_or_fd : file, gzipped file, pipe or opened file descriptor. + + Returns vector of tuples. 
+ """ + fd = open_or_fd(file_or_fd) + binary = fd.read(2).decode(); assert(binary == '\0B'); # assuming it's binary + + assert(fd.read(1).decode() == '\4'); # int-size + vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # number of frames (or bins) + + data = np.frombuffer(fd.read(vec_size*10), dtype=[('size_beg','int8'),('t_beg','float32'),('size_end','int8'),('t_end','float32')], count=vec_size) + assert(data[0]['size_beg'] == 4) + assert(data[0]['size_end'] == 4) + ans = data[['t_beg','t_end']].tolist() # Return vector of tuples (t_beg,t_end), + + if fd is not file_or_fd : fd.close() + return ans + + +################################################# +# Segments related, +# + +# Segments as 'Bool vectors' can be handy, +# - for 'superposing' the segmentations, +# - for frame-selection in Speaker-ID experiments, +def read_segments_as_bool_vec(segments_file): + """ [ bool_vec ] = read_segments_as_bool_vec(segments_file) + using kaldi 'segments' file for 1 wav, format : ' ' + - t-beg, t-end is in seconds, + - assumed 100 frames/second, + """ + segs = np.loadtxt(segments_file, dtype='object,object,f,f', ndmin=1) + # Sanity checks, + assert(len(segs) > 0) # empty segmentation is an error, + assert(len(np.unique([rec[1] for rec in segs ])) == 1) # segments with only 1 wav-file, + # Convert time to frame-indexes, + start = np.rint([100 * rec[2] for rec in segs]).astype(int) + end = np.rint([100 * rec[3] for rec in segs]).astype(int) + # Taken from 'read_lab_to_bool_vec', htk.py, + frms = np.repeat(np.r_[np.tile([False,True], len(end)), False], + np.r_[np.c_[start - np.r_[0, end[:-1]], end-start].flat, 0]) + assert np.sum(end-start) == np.sum(frms) + return frms + diff --git a/speech_tools/utils.py b/speech_tools/utils.py new file mode 100644 index 000000000..ac66f4277 --- /dev/null +++ b/speech_tools/utils.py @@ -0,0 +1,75 @@ +# Copyright (c) 2018-present, Yiming Wang +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. 
+ +import re + +import torch + + +class Tokenizer: + + @staticmethod + def tokenize(sent, space='', non_lang_syms=None): + sent = ' '.join(sent.strip().split()) + + match_pos = [] + if non_lang_syms is not None: + assert isinstance(non_lang_syms, list) + prog = re.compile('|'.join(map(re.escape, non_lang_syms))) + matches = prog.finditer(sent) + for match in matches: + match_pos.append([match.start(), match.end()]) + + tokens = [] + i = 0 + for (start_pos, end_pos) in match_pos: + tokens.extend([token for token in sent[i:start_pos]]) + tokens.append(sent[start_pos:end_pos]) + i = end_pos + tokens.extend([token for token in sent[i:]]) + + tokens = [space if token == ' ' else token for token in tokens] + return ' '.join(tokens) + + @staticmethod + def tokens_to_index_tensor(line, dict, append_eos=True): + tokens = line.strip().split() + ntokens = len(tokens) + ids = torch.IntTensor(ntokens + 1 if append_eos else ntokens) + + for i, token in enumerate(tokens): + ids[i] = dict.index(token) + if append_eos: + ids[ntokens] = dict.eos_index + return ids + + @staticmethod + def tokens_to_sentence(line, dict): + tokens = line.strip().split() + sent = "" + for token in tokens: + if token == dict.space_word: + sent += " " + elif dict.index(token) == dict.unk(): + sent += dict.unk_word + elif token != dict.pad_word and token != dict.eos_word: + sent += token + return sent.strip() + +def collate_frames(values, pad_value=0.0, left_pad=False): + """Convert a list of 2d tensor into a padded 3d tensor.""" + assert values[0].dim() == 2, "expected 2, got " + str(values[0].dim) + length = max(v.size(0) for v in values) + dim = values[0].size(1) + res = values[0].new(len(values), length, dim).fill_(pad_value) + + for i, v in enumerate(values): + dst = res[i][length - v.size(0):, :] if left_pad \ + else res[i][:v.size(0), :] + assert dst.numel() == v.numel() + dst.copy_(v) + return res diff --git a/tests/test_speech_dataset.py b/tests/test_speech_dataset.py new file mode 100644 index 000000000..e59df2ea3 --- /dev/null +++ b/tests/test_speech_dataset.py @@ -0,0 +1,176 @@ +# Copyright (c) 2018-present, Yiming Wang +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. 
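# A minimal sketch (not part of the original patch) of the character-level
# scheme implemented by speech_tools.utils.Tokenizer above, assuming the
# dictionary's space symbol is a dedicated marker token such as '<space>':
# words are split into single-character tokens, word boundaries become the
# space symbol, and tokens_to_sentence() reverses the mapping.
from fairseq.data import TokenDictionary
from speech_tools.utils import Tokenizer

d = TokenDictionary()
for c in 'abc':
    d.add_symbol(c)

tokens = Tokenizer.tokenize('ab c', space=d.space_word)
# tokens is now 'a b <space-symbol> c', with d.space_word standing in for ' '
ids = Tokenizer.tokens_to_index_tensor(tokens, d, append_eos=True)
# ids holds the dictionary indices of the tokens, followed by d.eos_index
sent = Tokenizer.tokens_to_sentence(tokens, d)   # recovers 'ab c'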
+ +import unittest +import string +import numpy as np +import os + +import torch + +from fairseq.data import ( + SpeechDataset, TokenDictionary, TokenTextDataset, ScpCachedDataset, + ScpInMemoryDataset) + +import speech_tools.kaldi_io as kaldi_io + + +class TestSpeechDataset(unittest.TestCase): + + @staticmethod + def make_dictionary(): + """construct dictionary.""" + d = TokenDictionary() + alphabet = string.ascii_lowercase + for token in alphabet: + d.add_symbol(token) + d.finalize(padding_factor=1) # don't add extra padding symbols + return d + + @staticmethod + def generate_feats(test_dir, num=10, seed=0): + """generate feature matrices.""" + feats = {} + np.random.seed(seed) + with open(os.path.join(test_dir, 'feats.scp'), 'w', + encoding='utf-8') as f: + for i in range(num): + utt_id = 'utt_id_' + str(i) + ark_file = os.path.join(test_dir, 'mat_' + str(i) + '.ark') + f.write(utt_id + ' ' + ark_file + ':0\n') + length = np.random.randint(200, 800) + m = np.random.uniform(-10.0, 10.0, (length, 40)) + feats[utt_id] = m + kaldi_io.write_mat(ark_file, m) + return feats + + @staticmethod + def generate_text_tokens(test_dir, num=10, seed=0): + """generate token text, where utterances are in a (random) different + order from those in feats.scp.""" + text_tokens = {} + alphabet = string.ascii_lowercase + space = '' + vocab = list(alphabet) + vocab.append(space) + np.random.seed(seed) + with open(os.path.join(test_dir, 'text_tokens'), 'w', + encoding='utf-8') as f: + for i in np.random.permutation(range(num)): + utt_id = 'utt_id_' + str(i) + length = np.random.randint(10, 100) + tokens = [vocab[np.random.randint(0, len(vocab))] \ + for _ in range(length)] + if tokens[0] == space: + tokens[0] = vocab[np.random.randint(0, len(vocab) - 1)] + if tokens[-1] == space: + tokens[-1] = vocab[np.random.randint(0, len(vocab) - 1)] + text_tokens[utt_id] = tokens + f.write(utt_id + ' ' + ' '.join(tokens) + '\n') + return text_tokens + + def setUp(self): + self.test_dir = './temp' + os.makedirs(self.test_dir, exist_ok=True) + self.num_audios = 150 + self.num_transripts = 100 + self.batch_size = 8 + self.cache_size = 16 + self.dict = self.make_dictionary() + self.expected_feats = self.generate_feats(self.test_dir, + num=self.num_audios, seed=0) + self.expected_tokens = self.generate_text_tokens(self.test_dir, + num=self.num_transripts, seed=1) + + self.cuda = torch.cuda.is_available() + + def _speech_dataset_helper(self, all_in_memory=False, + ordered_prefetch=False): + if not all_in_memory: + src_dataset = ScpCachedDataset( + path=os.path.join(self.test_dir, 'feats.scp'), + ordered_prefetch=ordered_prefetch, + cache_size=self.cache_size, + ) + else: + src_dataset = ScpInMemoryDataset( + path=os.path.join(self.test_dir, 'feats.scp') + ) + tgt_dataset = TokenTextDataset( + path=os.path.join(self.test_dir, 'text_tokens'), + dictionary=self.dict, + ) + + dataset = SpeechDataset( + src_dataset, src_dataset.sizes, + tgt_dataset, tgt_dataset.sizes, self.dict, + left_pad_source=False, + left_pad_target=False, + max_source_positions=1000, + max_target_positions=200, + ) + + # assume one is a subset of the other + expected_dataset_size = min(self.num_audios, self.num_transripts) + self.assertEqual(len(dataset.src), expected_dataset_size) + self.assertEqual(len(dataset.tgt), expected_dataset_size) + + indices = list(range(expected_dataset_size)) + batch_sampler = [] + for i in range(0, expected_dataset_size, self.batch_size): + batch_sampler.append(indices[i:i+self.batch_size]) + + if not all_in_memory: + 
dataset.prefetch(indices) + + dataloader = torch.utils.data.DataLoader( + dataset, + collate_fn=dataset.collater, + batch_sampler=batch_sampler, + ) + + for i, batch in enumerate(iter(dataloader)): + bsz = batch["nsentences"] + self.assertEqual(bsz, len(batch_sampler[i])) + src_frames = batch["net_input"]["src_tokens"] + src_lengths = batch["net_input"]["src_lengths"] + tgt_tokens = self.dict.string(batch["target"]).split('\n') + tgt_tokens = [line.split(' ') for line in tgt_tokens] + self.assertEqual(bsz, src_frames.size(0)) + self.assertEqual(bsz, src_lengths.numel()) + self.assertEqual(bsz, len(tgt_tokens)) + for j, utt_id in enumerate(batch["utt_id"]): + self.assertTensorEqual( + torch.from_numpy(self.expected_feats[utt_id]).float(), + src_frames[j, :src_lengths[j], :] + ) + self.assertEqual( + self.expected_tokens[utt_id], + tgt_tokens[j], + ) + + def test_speech_dataset_cached_no_ordered_prefetch(self): + self._speech_dataset_helper(all_in_memory=False, ordered_prefetch=False) + + def test_speech_dataset_cached_with_ordered_prefetch(self): + self._speech_dataset_helper(all_in_memory=False, ordered_prefetch=True) + + def test_speech_dataset_all_in_memory(self): + self._speech_dataset_helper(all_in_memory=True) + + def assertTensorEqual(self, t1, t2): + self.assertEqual(t1.size(), t2.size(), "size mismatch") + if (t1.dtype == torch.short or t1.dtype == torch.int or \ + t1.dtype == torch.long) and (t2.dtype == torch.short or \ + t2.dtype == torch.int or t2.dtype == torch.long): + self.assertEqual(t1.ne(t2).long().sum(), 0) + else: + self.assertEqual(t1.allclose(t2,rtol=1e-05, atol=1e-08), True) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_speech_utils.py b/tests/test_speech_utils.py new file mode 100644 index 000000000..d42c4bcbf --- /dev/null +++ b/tests/test_speech_utils.py @@ -0,0 +1,104 @@ +# Copyright (c) 2018-present, Yiming Wang +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. 
+ +import unittest +import string +import numpy as np +import os + +import torch + +from fairseq.data import TokenDictionary + +import speech_tools.utils as utils + + +class TestSpeechUtils(unittest.TestCase): + + @staticmethod + def make_dictionary(vocab, non_lang_syms=[]): + """construct dictionary.""" + assert isinstance(vocab, list) and isinstance(non_lang_syms, list) + d = TokenDictionary() + for token in vocab: + d.add_symbol(token) + for token in non_lang_syms: + d.add_symbol(token) + d.finalize(padding_factor=1) # don't add extra padding symbols + return d + + @staticmethod + def generate_text(vocab, oovs=[], non_lang_syms=[], seed=0): + """generate text of one synthetic sentence.""" + assert isinstance(vocab, list) and isinstance(oovs, list) and \ + isinstance(non_lang_syms, list) + np.random.seed(seed) + sent_len = np.random.randint(2, 30) + sent = '' + for _ in range(sent_len): + if len(non_lang_syms) > 0 and np.random.randint(0, 20) == 0: + word = non_lang_syms[np.random.randint(0, len(non_lang_syms))] + else: + word = '' + word_len = np.random.randint(2, 11) + for _ in range(word_len): + if len(oovs) > 0 and np.random.randint(0, 20) == 0: + word += oovs[np.random.randint(0, len(oovs))] + else: + word += vocab[np.random.randint(0, len(vocab))] + sent += word + ' ' + + sent = ' '.join(sent.strip().split(' ')) + return sent + + def setUp(self): + self.vocab = list(string.ascii_lowercase) + self.oovs = list(string.ascii_uppercase) + self.non_lang_syms = ['', '', ''] + self.num_sentences = 100 + self.dict = self.make_dictionary(self.vocab, + non_lang_syms=self.non_lang_syms, + ) + self.text = [self.generate_text(self.vocab, self.oovs, + self.non_lang_syms, seed=i) for i in range(self.num_sentences)] + + def test_speech_tokenizer(self): + for i, sent in enumerate(self.text): + print('test sentence {}:'.format(i)) + print(sent) + tokens = utils.Tokenizer.tokenize(sent, \ + space=self.dict.space_word, non_lang_syms=self.non_lang_syms) + + # test Tokenizer.tokenize() with Tokenizer.tokens_to_index_tensor() + tensor = utils.Tokenizer.tokens_to_index_tensor(tokens, self.dict, \ + append_eos=True) + reconstructed_tokens = self.dict.string(tensor) + expected_tokens = ' '.join( + [token if self.dict.index(token) != self.dict.unk() else \ + self.dict.unk_word for token in tokens.split(' ')] + ) + self.assertEqual(reconstructed_tokens, expected_tokens) + + # test Tokenizer.tokenize() with Tokenizer.tokens_to_sentence() + reconstructed_sent = utils.Tokenizer.tokens_to_sentence(tokens, + self.dict) + expected_sent = [] + words = sent.split(' ') + for w in words: + if w not in self.non_lang_syms: + new_word = ''.join( + [self.dict.unk_word if c in self.oovs else c for c in w] + ) + expected_sent.append(new_word) + else: + expected_sent.append(w) + expected_sent = ' '.join(expected_sent) + self.assertEqual(reconstructed_sent, expected_sent) + + +if __name__ == "__main__": + unittest.main() From dbb20ad557980ed601dcb949473755c3e553df77 Mon Sep 17 00:00:00 2001 From: freewym Date: Sat, 15 Dec 2018 18:33:31 -0500 Subject: [PATCH 002/119] asr models related --- fairseq/data/speech_dataset.py | 2 +- fairseq/models/speech_lstm.py | 523 ++++++++++++++++++++++++++++ fairseq/modules/speech_attention.py | 122 +++++++ fairseq/tasks/speech_recognition.py | 49 ++- speech_tools/utils.py | 185 ++++++++++ tests/test_speech_utils.py | 6 +- 6 files changed, 879 insertions(+), 8 deletions(-) create mode 100644 fairseq/models/speech_lstm.py create mode 100644 fairseq/modules/speech_attention.py diff --git 
a/fairseq/data/speech_dataset.py b/fairseq/data/speech_dataset.py index 06aa78e78..2b1132970 100644 --- a/fairseq/data/speech_dataset.py +++ b/fairseq/data/speech_dataset.py @@ -128,7 +128,7 @@ def __init__( def _match_src_tgt(self): """Makes utterances in src and tgt the same order in terms of - their utt_ids. Removes those that only appear in one of them.""" + their utt_ids. Removes those that are only present in one of them.""" assert self.tgt is not None if self.src.utt_ids == self.tgt.utt_ids: return diff --git a/fairseq/models/speech_lstm.py b/fairseq/models/speech_lstm.py new file mode 100644 index 000000000..37798d338 --- /dev/null +++ b/fairseq/models/speech_lstm.py @@ -0,0 +1,523 @@ +# Copyright (c) 2018-present, Yiming Wang +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairseq import options, utils +from fairseq.modules import AdaptiveSoftmax, speech_attention +from . import ( + FairseqEncoder, FairseqIncrementalDecoder, FairseqModel, register_model, + register_model_architecture, +) + +from .lstm import AttentionLayer, Embedding, LSTM, LSTMCell, Linear + +import speech_tools.utils as speech_utils + + +@register_model('speech_lstm') +class SpeechLSTMModel(FairseqModel): + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument('--dropout', type=float, metavar='D', + help='dropout probability') + parser.add_argument('feat-dim', type=int, metavar='N', + help='input feature dimension') + paser.add_argument('--encoder-conv-channels', type=str, metavar='STR', + help='list of encoder convolution\'s out channels') + paser.add_argument('--encoder-conv-kernel-size', type=str, metavar='STR', + help='list of encoder convolution\'s kernel size') + paser.add_argument('--encoder-conv-stride', type=str, metavar='STR', + help='list of encoder convolution\'s stride') + parser.add_argument('--encoder-rnn-hidden-size', type=int, metavar='N', + help='encoder rnn\'s hidden size') + parser.add_argument('--encoder-rnn-layers', type=int, metavar='N', + Lhelp='number of rnn encoder layers') + parser.add_argument('--encoder-rnn-bidirectional', action='store_true', + help='make all rnn layers of encoder bidirectional') + parser.add_argument('--decoder-embed-dim', type=int, metavar='N', + help='decoder embedding dimension') + parser.add_argument('--decoder-embed-path', type=str, metavar='STR', + help='path to pre-trained decoder embedding') + parser.add_argument('--decoder-hidden-size', type=int, metavar='N', + help='decoder hidden size') + parser.add_argument('--decoder-layers', type=int, metavar='N', + help='number of decoder layers') + parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N', + help='decoder output embedding dimension') + parser.add_argument('--attention-type', type=str, metavar='STR', + choices=['bahdanau','luong'], default='bahdanau', + help='attention type') + parser.add_argument('--attention-dim', type=int, metavar='N', + help='attention dimension') + parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', + help='comma separated list of adaptive softmax cutoff points. 
' + 'Must be used with adaptive_loss criterion') + + # Granular dropout settings (if not specified these default to --dropout) + parser.add_argument('--encoder-rnn-dropout-in', type=float, metavar='D', + help='dropout probability for encoder rnn\'s input') + parser.add_argument('--encoder-rnn-dropout-out', type=float, metavar='D', + help='dropout probability for encoder rnn\'s output') + parser.add_argument('--decoder-dropout-in', type=float, metavar='D', + help='dropout probability for decoder input embedding') + parser.add_argument('--decoder-dropout-out', type=float, metavar='D', + help='dropout probability for decoder output') + parser.add_argument('--share-decoder-input-output-embed', default=False, + action='store_true', + help='share decoder input and output embeddings') + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + # make sure that all args are properly defaulted (in case there are any new ones) + base_architecture(args) + + def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx) + embed_dict = utils.parse_embedding(embed_path) + utils.print_embed_overlap(embed_dict, dictionary) + return utils.load_embedding(embed_dict, dictionary, embed_tokens) + + # separate decoder input embeddings + pretrained_decoder_embed = None + if args.decoder_embed_path: + pretrained_decoder_embed = load_pretrained_embedding_from_file( + args.decoder_embed_path, + task.target_dictionary, + args.decoder_embed_dim + ) + # one last double check of parameter combinations + if args.share_decoder_input_output_embed and ( + args.decoder_embed_dim != args.decoder_out_embed_dim): + raise ValueError( + '--share-decoder-input-output-embeddings requires ' + '--decoder-embed-dim to match --decoder-out-embed-dim' + ) + + out_channels = options.eval_str_list(args.encoder_conv_channels, + type=int) + kernel_size = options.eval_str_list(args.encoder_conv_kernel_size, + type=int) + stride = options.eval_str_list(args.encoder_conv_stride, type=int) + in_channel = 1 # hard-coded for now + conv_layers = ConvBNReLU(out_channels, kernel_size, stride, + in_channel=in_channel) if not out_channels is None else None + + rnn_encoder_input_size = args.feat_dim // in_channel + if conv_layers is not None: + for s in stride: + rnn_encoder_input_size = (rnn_input_size + s[1] - 1) // s[1] + rnn_encoder_input_size *= out_channels[-1] + + encoder = SpeechLSTMEncoder( + conv_layers_before=conv_layers, + input_size=rnn_encoder_input_size, + hidden_size=args.encoder_rnn_hidden_size, + num_layers=args.encoder_rnn_layers, + dropout_in=args.encoder_rnn_dropout_in, + dropout_out=args.encoder_rnn_dropout_out, + bidirectional=args.encoder_rnn_bidirectional, + ) + decoder = SpeechLSTMDecoder( + dictionary=task.target_dictionary, + embed_dim=args.decoder_embed_dim, + hidden_size=args.decoder_hidden_size, + out_embed_dim=args.decoder_out_embed_dim, + num_layers=args.decoder_layers, + dropout_in=args.decoder_dropout_in, + dropout_out=args.decoder_dropout_out, + encoder_output_units=encoder.output_units, + attn_type=args.attention_type, + attn_dim=args.attention_dim, + pretrained_embed=pretrained_decoder_embed, + share_input_output_embed=args.share_decoder_input_output_embed, + adaptive_softmax_cutoff=( + options.eval_str_list(args.adaptive_softmax_cutoff, type=int) + if args.criterion == 'adaptive_loss' else None + ), + ) + return cls(encoder, decoder) + + 
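The ceil-division loop at the end of build_model determines the flattened input size of the encoder LSTM: each conv layer downsamples the frequency axis by its stride, and the surviving frequency bins are multiplied by the last layer's channel count. A small self-contained sketch of that arithmetic follows; the example numbers (40-dim features, two layers of frequency stride 2, 128 output channels) are illustrative only.

def lstm_input_size(feat_dim, strides, out_channels, in_channel=1):
    """Mirror the intent of the ceil-division in build_model above."""
    size = feat_dim // in_channel
    for s in strides:
        size = (size + s[1] - 1) // s[1]  # ceil(size / freq_stride)
    return size * out_channels[-1]

if __name__ == "__main__":
    # e.g. 40-dim fbank, two conv layers with stride (2, 2), channels [64, 128]
    print(lstm_input_size(40, [(2, 2), (2, 2)], [64, 128]))  # 10 * 128 = 1280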
+class ConvBNReLU(nn.Module): + """Sequence of convolution-BatchNorm-ReLU layers.""" + def __init__(self, out_channels, kernel_size, stride, in_channel=1) + super().__init__() + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.in_channel = in_channel + + self.num_layers = len(out_channels) + assert num_layers == len(kernel_size) and num_layers == len(stride) + + self.convolutions = nn.ModuleList() + self.batchnorms = nn.ModuleList() + for i in range(self.num_layers): + self.convolutions.append(Convolution2d( + self.in_channel if i == 0 else self.out_channels[i-1], + self.out_channels[i], + self.kernel_size[i], self.stride[i])) + ) + self.batchnorms.append(nn.BatchNorm2d(out_channels[i])) + + def forward(self, src, src_lengths): + # B X T X C -> B X (input channel num) x T X (C / input channel num) + x = src.view(src.size(0), src.size(1), self.in_channel, + src.size(2) // self.in_channel).transpose(1, 2) + for conv, bn in zip(self.convolutions, self.batchnorms): + x = F.relu(bn(conv(x))) + # B X (output channel num) x T X C' -> B X T X (output channel num) X C' + x = x.transpose(1, 2) + # B X T X (output channel num) X C' -> B X T X C + x = x.contiguous().view(x.size(0), x.size(1), x.size(2) * x.size(3)) + + x_lengths = src_lengths + for i in range(self.num_layers): + x_lengths = (x_lengths + self.stride[0] - 1) // self.stride[0] + padding_mask = 1 - speech_utils.sequence_mask(x_lengths, x.size(1)) + if padding_mask.any(): + x = x.masked_fill(padding_mask.unsqueeze(-1), 0.0) + + return x, x_lengths, padding_mask + + +class SpeechLSTMEncoder(FairseqEncoder): + """LSTM encoder.""" + def __init__( + self, conv_layers_before=None, input_size=40, hidden_size=512, + num_layers=1, dropout_in=0.1, dropout_out=0.1, bidirectional=False, + left_pad=False, pretrained_embed=None, padding_value=0., + ): + super().__init__(None) # no src dictionary + self.conv_layers_before = conv_layers_before + self.num_layers = num_layers + self.dropout_in = dropout_in + self.dropout_out = dropout_out + self.bidirectional = bidirectional + self.hidden_size = hidden_size + + self.lstm = LSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + dropout=self.dropout_out if num_layers > 1 else 0., + bidirectional=bidirectional, + ) + self.left_pad = left_pad + self.padding_value = padding_value + + self.output_units = hidden_size + if bidirectional: + self.output_units *= 2 + + + def forward(self, src_tokens, src_lengths): + if self.left_pad: + # convert left-padding to right-padding + src_tokens = speech_utils.convert_padding_direction( + src_tokens, + src_lengths, + left_to_right=True, + ) + + if self.conv_layers_before is not None: + x, src_lengths, padding_mask = self.conv_layers_before(src_tokens, + src_lengths) + else: + x = src_tokens + + bsz, seqlen = x.size(0), x.size(1) + + x = F.dropout(x, p=self.dropout_in, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + # pack embedded source tokens into a PackedSequence + packed_x = nn.utils.rnn.pack_padded_sequence(x, src_lengths.data.tolist()) + + # apply LSTM + if self.bidirectional: + state_size = 2 * self.num_layers, bsz, self.hidden_size + else: + state_size = self.num_layers, bsz, self.hidden_size + h0 = x.data.new(*state_size).zero_() + c0 = x.data.new(*state_size).zero_() + packed_outs, (final_hiddens, final_cells) = self.lstm(packed_x, (h0, c0)) + + # unpack outputs and apply dropout + x, _ = nn.utils.rnn.pad_packed_sequence(packed_outs, 
padding_value=self.padding_value) + x = F.dropout(x, p=self.dropout_out, training=self.training) + assert list(x.size()) == [seqlen, bsz, self.output_units] + + if self.bidirectional: + + def combine_bidir(outs): + return outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous().view(self.num_layers, bsz, -1) + + final_hiddens = combine_bidir(final_hiddens) + final_cells = combine_bidir(final_cells) + + encoder_padding_mask = padding_mask.t() + + return { + 'encoder_out': (x, final_hiddens, final_cells), + 'encoder_padding_mask': encoder_padding_mask if encoder_padding_mask.any() else None + } + + def reorder_encoder_out(self, encoder_out, new_order): + encoder_out['encoder_out'] = tuple( + eo.index_select(1, new_order) + for eo in encoder_out['encoder_out'] + ) + if encoder_out['encoder_padding_mask'] is not None: + encoder_out['encoder_padding_mask'] = \ + encoder_out['encoder_padding_mask'].index_select(1, new_order) + return encoder_out + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return int(1e5) # an arbitrary large number + + +class SpeechLSTMDecoder(FairseqIncrementalDecoder): + """LSTM decoder.""" + def __init__( + self, dictionary, embed_dim=512, hidden_size=512, out_embed_dim=512, + num_layers=1, dropout_in=0.1, dropout_out=0.1, + encoder_output_units=512, attn_type='bahdanau', attn_dim=256, + pretrained_embed=None, share_input_output_embed=False, + adaptive_softmax_cutoff=None, + ): + super().__init__(dictionary) + self.dropout_in = dropout_in + self.dropout_out = dropout_out + self.hidden_size = hidden_size + self.share_input_output_embed = share_input_output_embed + self.need_attn = True + + self.adaptive_softmax = None + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + if pretrained_embed is None: + self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx) + else: + self.embed_tokens = pretrained_embed + + self.encoder_output_units = encoder_output_units + + self.layers = nn.ModuleList([ + LSTMCell( + input_size=encoder_output_units + (embed_dim if layer == 0 else hidden_size), + hidden_size=hidden_size, + ) + for layer in range(num_layers) + ]) + if attn.type == 'bahdanau': + self.attention = speech_attention.BahdanauAttention(hidden_size, + encoder_output_units, attn_dim) + elif attn.type == 'luong': + self.attention = speech_attention.LuongAttention(hidden_size, + encoder_output_units) + else: + raise ValueError('unrecognized attention type.') + if hidden_size != out_embed_dim: + self.additional_fc = Linear(hidden_size, out_embed_dim) + if adaptive_softmax_cutoff is not None: + # setting adaptive_softmax dropout to dropout_out for now but can be redefined + self.adaptive_softmax = AdaptiveSoftmax(num_embeddings, embed_dim, adaptive_softmax_cutoff, + dropout=dropout_out) + elif not self.share_input_output_embed: + self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out) + + def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None): + encoder_out = encoder_out_dict['encoder_out'] + encoder_padding_mask = encoder_out_dict['encoder_padding_mask'] + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + bsz, seqlen = prev_output_tokens.size() + + # get outputs from encoder + encoder_outs, _, _ = encoder_out[:3] + srclen = encoder_outs.size(0) + + # embed tokens + x = self.embed_tokens(prev_output_tokens) + x = F.dropout(x, p=self.dropout_in, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + # 
initialize previous states (or get from cache during incremental generation) + cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state') + if cached_state is not None: + prev_hiddens, prev_cells, input_feed = cached_state + else: + _, encoder_hiddens, encoder_cells = encoder_out[:3] + num_layers = len(self.layers) + prev_hiddens = [x.data.new(bsz, self.hidden_size).zero_() \ + for i in range(num_layers)] + prev_cells = [x.data.new(bsz, self.hidden_size).zero_() \ + for i in range(num_layers)] + input_feed = x.data.new(bsz, self.encoder_output_units).zero_() + + attn_scores = x.data.new(srclen, seqlen, bsz).zero_() + outs = [] + for j in range(seqlen): + # input feeding: concatenate context vector from previous time step + input = torch.cat((x[j, :, :], input_feed), dim=1) + + for i, rnn in enumerate(self.layers): + # recurrent cell + hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i])) + + # compute and apply attention using the 1st layer's hidden state + if i == 0: + context, attn_scores[:, j, :], _ = self.attention(hidden, + encoder_outs, encoder_padding_mask) + + # hidden state concatenated with context vector becomes the + # input to the next layer + input = torch.cat((hidden, context), dim=1) + input = F.dropout(input, p=self.dropout_out, training=self.training) + + # save state for next time step + prev_hiddens[i] = hidden + prev_cells[i] = cell + + # input feeding + input_feed = context + + # save final output + outs.append(input) + + # cache previous states (no-op except during incremental generation) + utils.set_incremental_state( + self, incremental_state, 'cached_state', (prev_hiddens, prev_cells, input_feed)) + + # collect outputs across time steps + x = torch.cat(outs, dim=0).view(seqlen, bsz, -1) + assert x.size(2) == self.hidden_size + self.encoder_output_units + + # T x B x C -> B x T x C + x = x.transpose(1, 0) + + # srclen x tgtlen x bsz -> bsz x tgtlen x srclen + if not self.training and self.need_attn: + attn_scores = attn_scores.transpose(0, 2) + else: + attn_scores = None + + # project back to size of vocabulary + if self.adaptive_softmax is None: + if hasattr(self, 'additional_fc'): + x = self.additional_fc(x) + x = F.dropout(x, p=self.dropout_out, training=self.training) + if self.share_input_output_embed: + x = F.linear(x, self.embed_tokens.weight) + else: + x = self.fc_out(x) + return x, attn_scores + + def reorder_incremental_state(self, incremental_state, new_order): + super().reorder_incremental_state(incremental_state, new_order) + cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state') + if cached_state is None: + return + + def reorder_state(state): + if isinstance(state, list): + return [reorder_state(state_i) for state_i in state] + return state.index_select(0, new_order) + + new_state = tuple(map(reorder_state, cached_state)) + utils.set_incremental_state(self, incremental_state, 'cached_state', new_state) + + def max_positions(self): + """Maximum output length supported by the decoder.""" + return int(1e5) # an arbitrary large number + + def make_generation_fast_(self, need_attn=False, **kwargs): + self.need_attn = need_attn + + +def Convolution2d(in_channels, out_channels, kernel_size, stride): + if len(kernel_size) != 2: + if len(kernel_size) == 1: + kernel_size = (kernel_size[0], kernel_size[0]) + else: + assert isinstance(kernel_size, int) + kernel_size = (kernel_size, kernel_size) + if len(stride) != 2: + if len(stride) == 1: + stride = (stride[0], stride[0]) + else: + assert 
isinstance(stride, int) + stride = (stride, stride) + assert kernel_size[0] % 2 == 1 and kernel_size[1] % 2 == 1 + padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2) + m = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, \ + padding=padding) + return m + + +@register_model_architecture('speech_lstm', 'speech_lstm') +def base_architecture(args): + args.dropout = getattr(args, 'dropout', 0.1) + args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512) + args.encoder_embed_path = getattr(args, 'encoder_embed_path', None) + args.encoder_hidden_size = getattr(args, 'encoder_hidden_size', args.encoder_embed_dim) + args.encoder_layers = getattr(args, 'encoder_layers', 1) + args.encoder_bidirectional = getattr(args, 'encoder_bidirectional', False) + args.encoder_dropout_in = getattr(args, 'encoder_dropout_in', args.dropout) + args.encoder_dropout_out = getattr(args, 'encoder_dropout_out', args.dropout) + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512) + args.decoder_embed_path = getattr(args, 'decoder_embed_path', None) + args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', args.decoder_embed_dim) + args.decoder_layers = getattr(args, 'decoder_layers', 1) + args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 512) + args.decoder_attention = getattr(args, 'decoder_attention', '1') + args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', args.dropout) + args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout) + args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False) + args.share_all_embeddings = getattr(args, 'share_all_embeddings', False) + args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,50000,200000') + +@register_model_architecture('speech_lstm', 'speech_lstm_wiseman_iwslt_de_en') +def lstm_wiseman_iwslt_de_en(args): + args.dropout = getattr(args, 'dropout', 0.1) + args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256) + args.encoder_dropout_in = getattr(args, 'encoder_dropout_in', 0) + args.encoder_dropout_out = getattr(args, 'encoder_dropout_out', 0) + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256) + args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256) + args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', 0) + args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout) + base_architecture(args) + + +@register_model_architecture('speech_lstm', 'speech_lstm_luong_wmt_en_de') +def lstm_luong_wmt_en_de(args): + args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1000) + args.encoder_layers = getattr(args, 'encoder_layers', 4) + args.encoder_dropout_out = getattr(args, 'encoder_dropout_out', 0) + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1000) + args.decoder_layers = getattr(args, 'decoder_layers', 4) + args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 1000) + args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', 0) + base_architecture(args) diff --git a/fairseq/modules/speech_attention.py b/fairseq/modules/speech_attention.py new file mode 100644 index 000000000..61d823cf7 --- /dev/null +++ b/fairseq/modules/speech_attention.py @@ -0,0 +1,122 @@ +# Copyright (c) 2018-present, Yiming Wang +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. 
An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import math +import torch +from torch import nn +from torch.nn import Parameter +import torch.nn.functional as F + +from fairseq import utils + + +class BaseAttention(nn.Module): + """Base class for attention layers.""" + + def __init__(self, query_dim, value_dim, embed_dim=None): + super().__init__() + self.query_dim = query_dim + self.value_dim = value_dim + self.embed_dim = embed_dim + + def reset_parameters(self): + pass + + def forward(self, query, value, key_padding_mask=None, state=None): + # query: bsz x q_hidden + # value: len x bsz x v_hidden + # key_padding_mask: len x bsz + raise NotImplementedError + + +class BahdanauAttention(BaseAttention): + """ Bahdanau Attention.""" + + def __init__(self, query_dim, value_dim, embed_dim, normalize=True): + super().__init__(query_dim, value_dim, embed_dim) + self.query_proj = nn.Linear(self.query_dim, self.embed_dim, bias=False) + self.value_proj = nn.Linear(self.value_dim, self.embed_dim, bias=False) + self.v = Parameter(torch.Tensor(embed_dim)) + self.normalize = normalize + if self.normalize: + self.b = Parameter(torch.Tensor(embed_dim)) + self.g = Parameter(torch.Tensor(1)) + + self.reset_parameters() + + def reset_parameters(self): + self.query_proj.weight.data.uniform_(-0.1, 0.1) + self.value_proj.weight.data.uniform_(-0.1, 0.1) + nn.init.uniform_(self.v, -0.1, 0.1) + if self.normalize: + nn.init.constant_(self.b, 0.) + nn.init.constant_(self.g, math.sqrt(1. / embed_dim)) + + def forward(self, query, value, key_padding_mask=None, state=None): + # projected_query: 1 x bsz x embed_dim + projected_query = self.query_proj(query).unsqueeze(0) + key = self.value_proj(value) # len x bsz x embed_dim + if self.normalize: + # normed_v = g * v / ||v|| + normed_v = self.g * self.v / torch.norm(self.v) + attn_scores = (normed_v * nn.tanh(projected_query + key + \ + self.b)).sum(dim=2) # len x bsz + else: + attn_scores = v * nn.tanh(projected_query + key).sum(dim=2) + + if encoder_padding_mask is not None: + attn_scores = attn_scores.float().masked_fill_( + encoder_padding_mask, float('-inf'), + ).type_as(attn_scores) # FP16 support: cast to float and back + + attn_scores = F.softmax(attn_scores, dim=0) # len x bsz + + # sum weighted value. context: bsz x value_dim + context = (attn_scores.unsqueeze(2) * value).sum(dim=0) + next_state = attn_scores + + return context, attn_scores, next_state + + +class LuongAttention(BaseAttention): + """ Luong Attention.""" + + def __init__(self, query_dim, value_dim, embed_dim=None, scale=True): + super().__init__(query_dim, value_dim, embed_dim) + self.value_proj = nn.Linear(self.value_dim, self.query_dim, bias=False) + self.scale = scale + if self.scale: + self.g = Parameter(torch.Tensor(1)) + + self.reset_parameters() + + def reset_parameters(self): + self.value_proj.weight.data.uniform_(-0.1, 0.1) + if self.scale: + nn.init.constant_(self.g, 1.) 
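# For reference, the tensor shapes assumed by these attention layers, taken
# from the comments in BaseAttention.forward and the code below:
#   query:            bsz x query_dim        -- current decoder hidden state
#   value:            len x bsz x value_dim  -- encoder outputs
#   key_padding_mask: len x bsz              -- True at padded source positions
# and both variants return
#   context:     bsz x value_dim  -- attention-weighted sum of the values
#   attn_scores: len x bsz        -- softmax over the source length (dim=0)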
+ + def forward(self, query, value, key_padding_mask=None, state=None): + query = self.query_proj(query).unsqueeze(1) # bsz x 1 x query_dim + key = self.value_proj(value).transpose(0, 1) # bsz x len x query_dim + attn_scores = torch.bmm(query, key.transpose(1, 2)).squeeze(1) + attn_scores = attn_scores.transpose(0, 1) # len x bsz + if self.scale: + attn_scores = self.g * attn_scores + + if encoder_padding_mask is not None: + attn_scores = attn_scores.float().masked_fill_( + encoder_padding_mask, float('-inf'), + ).type_as(attn_scores) # FP16 support: cast to float and back + + attn_scores = F.softmax(attn_scores, dim=0) # len x bsz + + # sum weighted value. context: bsz x value_dim + context = (attn_scores.unsqueeze(2) * value).sum(dim=0) + next_state = attn_scores + + return context, attn_scores, next_state + diff --git a/fairseq/tasks/speech_recognition.py b/fairseq/tasks/speech_recognition.py index 96858c14d..092c7595f 100644 --- a/fairseq/tasks/speech_recognition.py +++ b/fairseq/tasks/speech_recognition.py @@ -8,8 +8,9 @@ import itertools import numpy as np import os +import re -from fairseq import options +from fairseq import options, utils from fairseq.data import ( data_utils, TokenDictionary, SpeechDataset, ConcatDataset, TokenTextDataset, ScpCachedDataset @@ -42,8 +43,18 @@ class SpeechRecognitionTask(FairseqTask): @staticmethod def add_args(parser): """Add task-specific arguments to the parser.""" - parser.add_argument('scp_files', nargs='+', help='path(s) to scp file(s)') - parser.add_argument('text_files', nargs='+', help='path(s) to text file(s)') + parser.add_argument('--train-scp-files', nargs='+', + help='path(s) to scp file(s) for training') + parser.add_argument('--train-text-files', nargs='+', + help='path(s) to text file(s) for training') + parser.add_argument('--valid-scp-files', nargs='+', + help='path(s) to scp file(s) for validation') + parser.add_argument('--valid-text-files', nargs='+', + help='path(s) to text file(s) for validation') + parser.add_argument('--test-scp-files', nargs='+', + help='path(s) to scp file(s) for test') + parser.add_argument('--test-text-files', nargs='+', + help='path(s) to text file(s) for test') parser.add_argument('--dict', default=None, type=str, help='path to the dictionary') parser.add_argument('--raw-text', action='store_true', @@ -59,6 +70,20 @@ def add_args(parser): parser.add_argument('--upsample-primary', default=1, type=int, help='amount to upsample primary dataset') + @staticmethod + def load_pretrained_model(path, dict_path, arg_overrides=None): + model = utils.load_checkpoint_to_cpu(path) + args = model['args'] + state_dict = model['model'] + args = utils.override_model_args(args, arg_overrides) + dict = Dictionary.load(dict_path) + + task = SpeechRecognitionTask(args, dict) + model = task.build_model(args) + model.upgrade_state_dict(state_dict) + model.load_state_dict(state_dict, strict=True) + return model + def __init__(self, args, dict): super().__init__(args) self.dict = dict @@ -90,8 +115,22 @@ def load_dataset(self, split, combine=False, **kwargs): src_datasets = [] tgt_datasets = [] - assert len(self.args.scp_files) == len(self.args.text_files) - file_pairs = zip(self.args.scp_files, self.args.text_file) + if split == 'train': + scp_files = self.args.train_scp_files + text_files = self.args.train_text_files + assert len(scp_files) > 0 and len(text_files) > 0 + elif re.match(r"^valid\d*$", split): + scp_files = self.args.valid_scp_files + text_files = self.args.valid_text_files + assert len(scp_files) > 0 and 
len(text_files) > 0 + elif split == 'test': + scp_files = self.args.test_scp_files + text_files = self.args.test_text_files + assert len(scp_files) > 0 and len(text_files) > 0 + else: + raise ValueError('split should be one of "train", "valid*", "test"') + assert len(scp_files) == len(text_files) + file_pairs = zip(scp_files, text_files) for scp, text in enumerate(file_pairs): assert ScpCachedDataset.exists(scp) and TokenTextDataset.exists(text) src_datasets.append(ScpCachedDataset(scp, ordered_indices=True)) diff --git a/speech_tools/utils.py b/speech_tools/utils.py index ac66f4277..02fcac814 100644 --- a/speech_tools/utils.py +++ b/speech_tools/utils.py @@ -6,9 +6,13 @@ # can be found in the PATENTS file in the same directory. import re +import numpy as np +from collections import Counter import torch +from fairseq.utils import buffered_arange + class Tokenizer: @@ -73,3 +77,184 @@ def collate_frames(values, pad_value=0.0, left_pad=False): assert dst.numel() == v.numel() dst.copy_(v) return res + +def sequence_mask(sequence_length, max_len=None): + if max_len is None: + max_len = sequence_length.data.max() + else: + assert utils.item(sequence_length.data.max()) <= utils.item(max_len) + batch_size = sequence_length.size(0) + seq_range = torch.arange(0, max_len).long().to(device=sequence_length.device) + seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) + seq_length_expand = sequence_length.unsqueeze(1).expand_as(seq_range_expand) + return seq_range_expand < seq_length_expand + +def covnert_padding_direction(src_frames, src_lengths, right_to_left=False, + left_to_right=False): + """Counterpart of :func:`~fairseq.utils.convert_padding_direction`, + operating on 3d tensors of size B x T x C. + """ + assert right_to_left ^ left_to_right + assert src_frames.size(0) == src_lengths.size(0) + pad_mask = sequence_mask(src_lengths, max_len=src_frames.size(1)) + if not pad_mask.any(): + # no padding, return early + return src_frames + if left_to_right and not pad_mask[:, 0].any(): + # already right padded + return src_frames + if right_to_left and not pad_mask[:, -1].any(): + # already left padded + return src_frames + max_len = src_frames.size(1) + range = buffered_arange(max_len).type_as(src_frames).expand_as(src_frames) + num_pads = pad_mask.long().sum(dim=1, keepdim=True) + if right_to_left: + index = torch.remainder(range - num_pads, max_len) + else: + index = torch.remainder(range + num_pads, max_len) + return src_frames.gather(1, index) + +def edit_distance(ref, hyp): + """This function is to calculate the edit distance of reference sentence and + the hypothesis sentence using dynamic programming, and also backtrace to get + a list of edit steps. + + Args: + ref: list of words obtained by splitting reference sentence string + hyp: list of words obtained by splitting hypothesis sentence string + + Return: + dist: edit distance matrix of size len(ref) x len(hyp) + steps: list of edit steps + counter: object of collections.Counter containing counts of + reference words ('words'), number of correct words ('corr'), + substitutions ('sub'), insertions ('ins'), deletions ('del'). 
+ + + """ + + assert isinstance(ref, list) and isinstance(hyp, list) + + dist = numpy.zeros((len(ref) + 1, len(hyp) + 1), dtype=numpy.uint32) + for i in range(len(ref) + 1): + for j in range(len(hyp) + 1): + if i == 0: + d[0][j] = j + elif j == 0: + d[i][0] = i + for i in range(1, len(ref) + 1): + for j in range(1, len(hyp) + 1): + if ref[i - 1] == hyp[j - 1]: + dist[i][j] = dist[i - 1][j - 1] + else: + substitute = dist[i - 1][j - 1] + 1 + insert = dist[i][j - 1] + 1 + delete = dist[i - 1][j] + 1 + dist[i][j] = min(substitute, insert, delete) + + i = len(ref) + j = len(hyp) + steps = [] + while True: + if i == 0 and j == 0: + break + elif i >= 1 and j >= 1 and dist[i][j] == dist[i - 1][j - 1] + assert ref[i - 1] == hyp[j - 1] + steps.append('corr') + i = i - 1 + j = j - 1 + elif i >= 1 and j >= 1 and dist[i][j] == dist[i - 1][j - 1] + 1: + steps.append('sub') + i = i - 1 + j = j - 1 + elif j >= 1 and dist[i][j] == dist[i][j - 1] + 1: + steps.append('ins') + j = j - 1 + else: + assert i >= 1 and dist[i][j] == dist[i - 1][j] + 1 + steps.append('del') + i = i - 1 + steps = steps[::-1] + + counter = Counter({'words', len(ref)}) + counter.update(steps) + + return dist, steps, counter + +def aligned_print(ref, hyp, steps): + """This funcition is to print the result of comparing reference and + hypothesis sentences in an aligned way. + + Args: + ref: list of words obtained by splitting reference sentence string + hyp: list of words obtained by splitting hypothesis sentence string + steps: list of edit steps with elements 'corr', 'sub', 'ins' or 'del'. + """ + + assert isinstance(ref, list) and isinstance(hyp, list) + assert isinstance(steps, list) + + out_str = 'REF: ' + for i in range(len(steps)): + delim = ' ' if i < len(steps) - 1 else '\n' + if steps[i] == 'sub': + ref_idx = i - steps[:i].count('ins') + hyp_idx = i - steps[:i].count('del') + if len(ref[ref_idx]) < len(hyp[hyp_idx]): + out_str += ref[ref_idx] + + ' ' * (len(hyp[hyp_idx])-len(ref[ref_idx])) + delim + else: + out_str += ref[ref_idx] + delim + elif steps[i] == 'ins': + idx = i - steps[:i].count('del') + out_str += ' ' * len(hyp[idx] + delim + else: + assert steps[i] == 'del' or steps[i] == 'corr' + idx = i - steps[:i].count('ins') + out_str += ref[idx] + delim + + out_str += 'HYP: ' + for i in range(len(steps)): + delim = ' ' if i < len(steps) - 1 else '\n' + if steps[i] == 'sub': + ref_idx = i - steps[:i].count('ins') + hyp_idx = i - steps[:i].count('del') + if len(ref[ref_idx]) > len(hyp[hyp_idx]): + out_str += hyp[hyp_idx] + + ' ' * (len(ref[ref_idx])-len(hyp[hyp_idx])) + delim + else: + out_str += hyp[hyp_idx] + delim + elif steps[i] == 'del': + idx = i - steps[:i].count('ins') + out_str += ' ' * len(ref[idx] + delim + else: + assert steps[i] == 'ins' or steps[i] == 'corr' + idx = i - steps[:i].count('del') + out_str += hyp[idx] + delim + + out_str += 'STP: ' + for i in range(len(steps)): + delim = ' ' if i < len(steps) - 1 else '\n' + if steps[i] == 'sub': + ref_idx = i - steps[:i].count('ins') + hyp_idx = i - steps[:i].count('del') + if len(ref[ref_idx]) > len(hyp[hyp_idx]): + out_str += 'S' + ' ' * (len(ref[ref_idx]) - 1) + delim + else: + out_str += 'S' + ' ' * (len(hyp[hyp_idx]) - 1) + delim + elif steps[i] == 'ins': + idx = i - steps[:i].count('del') + out_str += 'I' + ' ' * (len(hyp[idx]) - 1) + delim + else: + assert steps[i] == 'del' or steps[i] == 'corr' + idx = i - steps[:i].count('ins') + sym = 'D' if step[i] == 'del' else ' ' + out_str += sym + ' ' * (len(ref[idx]) - 1) + delim + + counter = Counter(steps) + wer 
= float(counter['sub'] + counter['ins'] + counter['del']) / len(ref) \ + * 100 + out_str += 'WER: ' + '{:.2f}%'.format(wer) + '\n' + + return out_str diff --git a/tests/test_speech_utils.py b/tests/test_speech_utils.py index d42c4bcbf..2d55cbcb4 100644 --- a/tests/test_speech_utils.py +++ b/tests/test_speech_utils.py @@ -73,7 +73,8 @@ def test_speech_tokenizer(self): tokens = utils.Tokenizer.tokenize(sent, \ space=self.dict.space_word, non_lang_syms=self.non_lang_syms) - # test Tokenizer.tokenize() with Tokenizer.tokens_to_index_tensor() + # test :func:`~speech_tools.utils.Tokenizer.tokenize` with + # :func:`~speech_tools.utils.Tokenizer.tokens_to_index_tensor` tensor = utils.Tokenizer.tokens_to_index_tensor(tokens, self.dict, \ append_eos=True) reconstructed_tokens = self.dict.string(tensor) @@ -83,7 +84,8 @@ def test_speech_tokenizer(self): ) self.assertEqual(reconstructed_tokens, expected_tokens) - # test Tokenizer.tokenize() with Tokenizer.tokens_to_sentence() + # test :func:`~speech_tools.utils.Tokenizer.tokenize` with + # :func:`~speech_tools.utils.Tokenizer.tokens_to_sentence` reconstructed_sent = utils.Tokenizer.tokens_to_sentence(tokens, self.dict) expected_sent = [] From d6033af399374a365924a8a8f695e79ec84790b7 Mon Sep 17 00:00:00 2001 From: freewym Date: Mon, 17 Dec 2018 15:01:43 -0500 Subject: [PATCH 003/119] decoding related --- fairseq/criterions/cross_entropy_with_wer.py | 76 ++++ fairseq/models/speech_lstm.py | 20 +- fairseq/speech_recognizer.py | 60 +++ fairseq/wer.py | 94 +++++ speech_recognition.py | 169 +++++++++ speech_tools/utils.py | 66 ++-- speech_train.py | 365 +++++++++++++++++++ tests/test_speech_utils.py | 108 +++++- 8 files changed, 913 insertions(+), 45 deletions(-) create mode 100644 fairseq/criterions/cross_entropy_with_wer.py create mode 100644 fairseq/speech_recognizer.py create mode 100644 fairseq/wer.py create mode 100644 speech_recognition.py create mode 100644 speech_train.py diff --git a/fairseq/criterions/cross_entropy_with_wer.py b/fairseq/criterions/cross_entropy_with_wer.py new file mode 100644 index 000000000..79626fa79 --- /dev/null +++ b/fairseq/criterions/cross_entropy_with_wer.py @@ -0,0 +1,76 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import math +import torch.nn.functional as F + +from fairseq import utils, wer + +from . import FairseqCriterion, register_criterion +from .cross_entropy import CrossEntropyCriterion + + +@register_criterion('cross_entropy_with_wer') +class CrossEntropyWithWERCriterion(CrossEntropyCriterion): + + def __init__(self, args, task): + super().__init__(args, task) + dict = self.task.dict if hasattr(self.task, 'dict') \ + else self.task.tgt_dict + self.scorer = wer.Scorer(dict) + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. 
+ + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample['net_input']) + lprobs = model.get_normalized_probs(net_output, log_probs=True) + # wer code starts + if not model.training: + pred = lprobs.argmax(-1).int().cpu() # bsz x len + assert pred.size() == sample['net_input']['prev_output_tokens'].size() + assert pred.size() == sample['target'].size() + dict = self.task.dict if hasattr(self.task, 'dict') \ + else self.task.tgt_dict + self.scorer.reset() + ref_str_list = dict.string(sample['target'].int().cpu()).split('\n') + pred_str_list = dict.string(pred).split('\n') + for ref_str, pred_str in zip(ref_str_list, pred_str_list): + scorer.add(ref_str, pred_str) + # wer code ends + lprobs = lprobs.view(-1, lprobs.size(-1)) + target = model.get_targets(sample, net_output).view(-1) + loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx, + reduce=reduce) + sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] + logging_output = { + 'loss': utils.item(loss.data) if reduce else loss.data, + 'ntokens': sample['ntokens'], + 'nsentences': sample['target'].size(0), + 'sample_size': sample_size, + } + if not model.training: + logging_output['word_error'] = scorer.acc_word_error() + logging_output['word_count'] = scorer.acc_word_count() + return loss, sample_size, logging_output + + @staticmethod + def aggregate_logging_outputs(logging_outputs): + """Aggregate logging outputs from data parallel training.""" + agg_output = super().aggregate_logging_outputs(logging_outputs) + word_error = sum(log.get('word_error', 0) for log in logging_outputs) + word_count = sum(log.get('word_count', 0) for log in logging_outputs) + if word_count > 0: + agg_output['word_error'] = word_error + agg_output['word_count'] = word_count + else: + print('Not aggregating WER in training mode.') + return agg_output diff --git a/fairseq/models/speech_lstm.py b/fairseq/models/speech_lstm.py index 37798d338..3c2b41fd7 100644 --- a/fairseq/models/speech_lstm.py +++ b/fairseq/models/speech_lstm.py @@ -33,16 +33,16 @@ def add_args(parser): help='dropout probability') parser.add_argument('feat-dim', type=int, metavar='N', help='input feature dimension') - paser.add_argument('--encoder-conv-channels', type=str, metavar='STR', + parser.add_argument('--encoder-conv-channels', type=str, metavar='STR', help='list of encoder convolution\'s out channels') - paser.add_argument('--encoder-conv-kernel-size', type=str, metavar='STR', + parser.add_argument('--encoder-conv-kernel-size', type=str, metavar='STR', help='list of encoder convolution\'s kernel size') - paser.add_argument('--encoder-conv-stride', type=str, metavar='STR', + parser.add_argument('--encoder-conv-stride', type=str, metavar='STR', help='list of encoder convolution\'s stride') parser.add_argument('--encoder-rnn-hidden-size', type=int, metavar='N', help='encoder rnn\'s hidden size') parser.add_argument('--encoder-rnn-layers', type=int, metavar='N', - Lhelp='number of rnn encoder layers') + help='number of rnn encoder layers') parser.add_argument('--encoder-rnn-bidirectional', action='store_true', help='make all rnn layers of encoder bidirectional') parser.add_argument('--decoder-embed-dim', type=int, metavar='N', @@ -154,7 +154,7 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): class ConvBNReLU(nn.Module): """Sequence of 
convolution-BatchNorm-ReLU layers.""" - def __init__(self, out_channels, kernel_size, stride, in_channel=1) + def __init__(self, out_channels, kernel_size, stride, in_channel=1): super().__init__() self.out_channels = out_channels self.kernel_size = kernel_size @@ -167,11 +167,11 @@ def __init__(self, out_channels, kernel_size, stride, in_channel=1) self.convolutions = nn.ModuleList() self.batchnorms = nn.ModuleList() for i in range(self.num_layers): - self.convolutions.append(Convolution2d( - self.in_channel if i == 0 else self.out_channels[i-1], - self.out_channels[i], - self.kernel_size[i], self.stride[i])) - ) + self.convolutions.append( + Convolution2d( + self.in_channel if i == 0 else self.out_channels[i-1], + self.out_channels[i], + self.kernel_size[i], self.stride[i])) self.batchnorms.append(nn.BatchNorm2d(out_channels[i])) def forward(self, src, src_lengths): diff --git a/fairseq/speech_recognizer.py b/fairseq/speech_recognizer.py new file mode 100644 index 000000000..97905d142 --- /dev/null +++ b/fairseq/speech_recognizer.py @@ -0,0 +1,60 @@ +# Copyright (c) 2018-present, Yiming Wang +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import math + +import torch + +from fairseq import utils + +from fairseq.sequence_generator import SequenceGenerator + + +class SpeechRecognizer(SequenceGenerator): + def generate_batched_itr( + self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None, + cuda=False, timer=None, prefix_size=0, + ): + """Iterate over a batched dataset and yield individual transcription. + Args: + maxlen_a/b: generate sequences of maximum length ax + b, + where x is the source sentence length. + cuda: use GPU for generation + timer: StopwatchMeter for timing generations. + """ + if maxlen_b is None: + maxlen_b = self.maxlen + + for sample in data_itr: + s = utils.move_to_cuda(sample) if cuda else sample + if 'net_input' not in s: + continue + input = s['net_input'] + # model.forward normally channels prev_output_tokens into the decoder + # separately, but SequenceGenerator directly calls model.encoder + encoder_input = { + k: v for k, v in input.items() + if k != 'prev_output_tokens' + } + srclen = encoder_input['src_tokens'].size(1) + if timer is not None: + timer.start() + with torch.no_grad(): + hypos = self.generate( + encoder_input, + beam_size=beam_size, + maxlen=int(maxlen_a*srclen + maxlen_b), + prefix_tokens=s['target'][:, :prefix_size] if prefix_size > 0 else None, + ) + if timer is not None: + timer.stop(sum(len(h[0]['tokens']) for h in hypos)) + for i, id in enumerate(s['id'].data): + utt_id = s['utt_id'][i] + # remove padding + ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None + yield id, utt_id, ref, hypos[i] + diff --git a/fairseq/wer.py b/fairseq/wer.py new file mode 100644 index 000000000..00362cf68 --- /dev/null +++ b/fairseq/wer.py @@ -0,0 +1,94 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. 
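The generator above yields one (sample_id, utt_id, reference, hypotheses) tuple per utterance. A minimal sketch of how a caller might drain it, assuming `recognizer`, `itr` and `tgt_dict` are the SpeechRecognizer, batch iterator and target dictionary built elsewhere (as in speech_recognition.py below):

def dump_nbest(recognizer, itr, tgt_dict, use_cuda=True):
    """Illustrative only: print the 1-best hypothesis for each utterance."""
    for sample_id, utt_id, ref_tokens, hypos in recognizer.generate_batched_itr(
            itr, maxlen_a=0.0, maxlen_b=200, cuda=use_cuda):
        best = hypos[0]  # hypotheses come back sorted best-first
        print('{}\t{:.4f}\t{}'.format(
            utt_id, best['score'],
            tgt_dict.string(best['tokens'].int().cpu())))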
+ +from collections import Counter + +import speech_tools.utils as speech_utils + + +class Scorer(object): + def __init__(self, dict): + self.dict = dict + self.reset() + + def reset(self): + self.char_counter = Counter() + self.word_counter = Counter() + self.results = '' + self.aligned_results = '' + + def add_prediction(self, pred, utt_id=None): + if not isinstance(pred, str): + raise TypeError('pred must be a string(got {})'.format(type(pred))) + if utt_id is not None and not isinstance(utt_id, str): + raise TypeError('utt_id must be a string(got {}) if not None' + .format(type(utt_id))) + + pred_words= speech_utils.Tokenizer.tokens_to_sentence(pred, self.dict) + if utt_id is not None: + self.results += utt_id + '\n' + self.results += pred_words + '\n' + + def add(self, ref, pred, utt_id=None): + if not isinstance(ref, str): + raise TypeError('ref must be a string (got {})'.format(type(ref))) + if not isinstance(pred, str): + raise TypeError('pred must be a string(got {})'.format(type(pred))) + if utt_id is not None and not isinstance(utt_id, str): + raise TypeError('utt_id must be a string(got {}) if not None' + .format(type(utt_id))) + + # char level counts + _, _, counter = speech_utils.edit_distance(ref.strip().split(), + pred.strip().split()) + self.char_counter += counter + + # word level counts + ref_words = speech_utils.Tokenizer.tokens_to_sentence(ref, self.dict) + pred_words= speech_utils.Tokenizer.tokens_to_sentence(pred, self.dict) + ref_word_list, pred_word_list = ref_words.split(), pred_words.split() + _, steps, counter = speech_utils.edit_distance(ref_word_list, + pred_word_list) + self.word_counter += counter + if utt_id is not None: + self.aligned_results += utt_id + '\n' + self.aligned_results += speech_utils.aligned_print(ref_word_list, + pred_word_list, steps) + + def cer(self): + assert self.char_counter['words'] > 0 + cer = float(self.char_counter['sub'] + self.char_counter['ins'] + \ + self.char_counter['del']) / self.char_counter['words'] * 100 + sub = float(self.char_counter['sub']) / self.char_counter['words'] * 100 + ins = float(self.char_counter['ins']) / self.char_counter['words'] * 100 + dlt = float(self.char_counter['del']) / self.char_counter['words'] * 100 + return cer, sub, ins, dlt + + def wer(self): + assert self.word_counter['words'] > 0 + wer = float(self.word_counter['sub'] + self.word_counter['ins'] + \ + self.word_counter['del']) / self.word_counter['words'] * 100 + sub = float(self.word_counter['sub']) / self.word_counter['words'] * 100 + ins = float(self.word_counter['ins']) / self.word_counter['words'] * 100 + dlt = float(self.word_counter['del']) / self.word_counter['words'] * 100 + return wer, sub, ins, dlt + + def acc_word_error(self): + return self.word_counter['sub'] + self.word_counter['ins'] + \ + self.word_counter['del'] + + def acc_word_count(self): + return self.word_counter['words'] + + @property + def results(self): + return self.results + + @property + def aligned_results(self): + return self.aligned_results + diff --git a/speech_recognition.py b/speech_recognition.py new file mode 100644 index 000000000..ad7ce6ff9 --- /dev/null +++ b/speech_recognition.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python3 -u +# Copyright (c) 2018-present, Yiming Wang +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. 
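A minimal sketch of how the Scorer above is driven (mirroring its use in cross_entropy_with_wer.py and the recognition script below); the `pairs` argument is an assumption for illustration, and the reference/hypothesis strings are token strings such as those produced by dict.string().

from fairseq import wer

def score_hypotheses(tgt_dict, pairs):
    """Illustrative only: corpus-level WER/CER from (utt_id, ref, hyp) triples."""
    scorer = wer.Scorer(tgt_dict)
    for utt_id, ref_str, hyp_str in pairs:
        scorer.add(ref_str, hyp_str, utt_id=utt_id)
    # each call returns (error%, substitution%, insertion%, deletion%)
    return scorer.wer(), scorer.cer()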
+""" +Recognize pre-processed speech with a trained model. +""" + +import os + +import torch + +from fairseq import wer, data, options, progress_bar, tasks, utils +from fairseq.meters import StopwatchMeter +from fairseq.speech_recognizer import SpeechRecognizer + + +def main(args): + assert args.path is not None, '--path required for recognition!' + assert not args.sampling or args.nbest == args.beam, \ + '--sampling requires --nbest to be equal to --beam' + + if args.max_tokens is None and args.max_sentences is None: + args.max_tokens = 12000 + print(args) + + use_cuda = torch.cuda.is_available() and not args.cpu + + # Load dataset split + task = tasks.setup_task(args) + task.load_dataset(args.gen_subset) + print('| {} {} {} examples'.format(args.data, args.gen_subset, + len(task.dataset(args.gen_subset)))) + + # Set dictionary + dict = task.target_dictionary + + # Load ensemble + print('| loading model(s) from {}'.format(args.path)) + models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task, + model_arg_overrides=eval(args.model_overrides)) + + # Optimize ensemble for generation + for model in models: + model.make_generation_fast_( + beamable_mm_beam_size=None if args.no_beamable_mm else args.beam, + need_attn=args.print_alignment, + ) + if args.fp16: + model.half() + + # Load dataset (possibly sharded) + itr = task.get_batch_iterator( + dataset=task.dataset(args.gen_subset), + max_tokens=args.max_tokens, + max_sentences=args.max_sentences, + max_positions=utils.resolve_max_positions( + task.max_positions(), + *[model.max_positions() for model in models] + ), + ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=8, + num_shards=args.num_shards, + shard_id=args.shard_id, + ).next_epoch_itr(shuffle=False) + + # Initialize generator + gen_timer = StopwatchMeter() + recognizer = SpeechSequenceGenerator( + models, dict, beam_size=args.beam, minlen=args.min_len, + stop_early=(not args.no_early_stop), + normalize_scores=(not args.unnormalized), + len_penalty=args.lenpen, unk_penalty=args.unkpen, + sampling=args.sampling, sampling_topk=args.sampling_topk, + sampling_temperature=args.sampling_temperature, + diverse_beam_groups=args.diverse_beam_groups, + diverse_beam_strength=args.diverse_beam_strength, + ) + + if use_cuda: + recognizer.cuda() + + # Generate and compute WER + scorer = wer.Scorer(dict) + num_sentences = 0 + has_target = True + with progress_bar.build_progress_bar(args, itr) as t: + recognitions = recognizer.generate_batched_itr( + t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b, + cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size, + ) + + sps_meter = TimeMeter() + for sample_id, utt_id, target_tokens, hypos in recognitions: + # Process input and ground truth + has_target = target_tokens is not None + target_tokens = target_tokens.int().cpu() if has_target else None + + # Regenerate original sentences from tokens. 
+ if has_target: + target_str = dict.string(target_tokens, args.remove_bpe) + if not args.quiet: + print('T-{}\t{}'.format(utt_id, target_str)) + + # Process top predictions + for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]): + hypo_str = dict.string(hypo['tokens'].int().cpu(), remove_bpe) + + if not args.quiet: + print('H-{}\t{}\t{}'.format(utt_id, hypo['score'], hypo_str)) + print('P-{}\t{}'.format( + utt_id, + ' '.join(map( + lambda x: '{:.4f}'.format(x), + hypo['positional_scores'].tolist(), + )) + )) + + # Score and obtain attention only the top hypothesis + if i == 0: + # src_len x tgt_len + attention = hypo['attention'].float().cpu() \ + if hypo['attention'] is not None else None + scorer.add_prediction(hypo_str, utt_id=utt_id) + if has_target: + scorer.add(target_str, hypo_str, utt_id=utt_id) + + num_sentences += 1 + + print('| Recognized {} utterances in {:.1f}s ({:.2f} utterances/s)'.format( + num_sentences, gen_timer.sum, 1. / gen_timer.avg)) + + fn = 'results.txt' + with open(os.path.join(args.path, fn), 'w', encoding='utf-8') as f: + f.write(scorer.results) + print('| Decoded results saved as ' + f.name) + + if has_target: + print('| Recognize {} with beam={}: WER={:.2f}%, Sub={:.2f}%, ' + 'Ins={:.2f}%, Del={:.2f}%'.format(args.gen_subset, args.beam, + *(scorer.wer()))) + print('| CER={:.2f}%, Sub={:.2f}%, ' + 'Ins={:.2f}%, Del={:.2f}%'.format(*(scorer.cer()))) + + fn = 'wer.txt' + with open(os.path.join(args.path, fn), 'w', encoding='utf-8') as f: + f.write('WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%\n' + .format(*(scorer.wer()))) + print('| WER saved in ' + f.name) + + fn = 'cer.txt' + with open(os.path.join(args.path, fn), 'w', encoding='utf-8') as f: + f.write('CER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%\n' + .format(*(scorer.cer()))) + print('| CER saved in ' + f.name) + + fn = 'aligned_results.txt' + with open(os.path.join(args.path, fn), 'w', encoding='utf-8') as f: + f.write(scorer.aligned_results) + print('| Aligned results saved as ' + f.name) + + +if __name__ == '__main__': + parser = options.get_generation_parser() + args = options.parse_args_and_arch(parser) + main(args) diff --git a/speech_tools/utils.py b/speech_tools/utils.py index 02fcac814..12c6c2d17 100644 --- a/speech_tools/utils.py +++ b/speech_tools/utils.py @@ -11,7 +11,7 @@ import torch -from fairseq.utils import buffered_arange +from fairseq.utils import buffered_arange, item class Tokenizer: @@ -82,33 +82,30 @@ def sequence_mask(sequence_length, max_len=None): if max_len is None: max_len = sequence_length.data.max() else: - assert utils.item(sequence_length.data.max()) <= utils.item(max_len) + assert item(sequence_length.data.max()) <= item(max_len) batch_size = sequence_length.size(0) - seq_range = torch.arange(0, max_len).long().to(device=sequence_length.device) + seq_range = torch.arange(0, max_len).to(device=sequence_length.device, + dtype=sequence_length.dtype) seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) seq_length_expand = sequence_length.unsqueeze(1).expand_as(seq_range_expand) return seq_range_expand < seq_length_expand -def covnert_padding_direction(src_frames, src_lengths, right_to_left=False, +def convert_padding_direction(src_frames, src_lengths, right_to_left=False, left_to_right=False): """Counterpart of :func:`~fairseq.utils.convert_padding_direction`, - operating on 3d tensors of size B x T x C. + operating on 3d tensors of size B x T x C. 
Note that this function is unware + of whether it has already been right padded or left padded (since any real + value is legal for non-padded elements), so be clear of the actual padding + direction before calling this function. """ assert right_to_left ^ left_to_right assert src_frames.size(0) == src_lengths.size(0) - pad_mask = sequence_mask(src_lengths, max_len=src_frames.size(1)) - if not pad_mask.any(): + max_len = src_frames.size(1) + if not src_lengths.eq(max_len).any(): # no padding, return early return src_frames - if left_to_right and not pad_mask[:, 0].any(): - # already right padded - return src_frames - if right_to_left and not pad_mask[:, -1].any(): - # already left padded - return src_frames - max_len = src_frames.size(1) - range = buffered_arange(max_len).type_as(src_frames).expand_as(src_frames) - num_pads = pad_mask.long().sum(dim=1, keepdim=True) + range = buffered_arange(max_len).unsqueeze(-1).expand_as(src_frames) + num_pads = (max_len - src_lengths.type_as(range)).unsqueeze(-1).unsqueeze(-1) if right_to_left: index = torch.remainder(range - num_pads, max_len) else: @@ -130,19 +127,17 @@ def edit_distance(ref, hyp): counter: object of collections.Counter containing counts of reference words ('words'), number of correct words ('corr'), substitutions ('sub'), insertions ('ins'), deletions ('del'). - - """ assert isinstance(ref, list) and isinstance(hyp, list) - dist = numpy.zeros((len(ref) + 1, len(hyp) + 1), dtype=numpy.uint32) + dist = np.zeros((len(ref) + 1, len(hyp) + 1), dtype=np.uint32) for i in range(len(ref) + 1): for j in range(len(hyp) + 1): if i == 0: - d[0][j] = j + dist[0][j] = j elif j == 0: - d[i][0] = i + dist[i][0] = i for i in range(1, len(ref) + 1): for j in range(1, len(hyp) + 1): if ref[i - 1] == hyp[j - 1]: @@ -159,15 +154,14 @@ def edit_distance(ref, hyp): while True: if i == 0 and j == 0: break - elif i >= 1 and j >= 1 and dist[i][j] == dist[i - 1][j - 1] - assert ref[i - 1] == hyp[j - 1] + elif i >= 1 and j >= 1 and dist[i][j] == dist[i - 1][j - 1] and \ + ref[i - 1] == hyp[j - 1]: steps.append('corr') - i = i - 1 - j = j - 1 + i, j = i - 1, j - 1 elif i >= 1 and j >= 1 and dist[i][j] == dist[i - 1][j - 1] + 1: + assert ref[i - 1] != hyp[j - 1] steps.append('sub') - i = i - 1 - j = j - 1 + i, j = i - 1, j - 1 elif j >= 1 and dist[i][j] == dist[i][j - 1] + 1: steps.append('ins') j = j - 1 @@ -177,7 +171,8 @@ def edit_distance(ref, hyp): i = i - 1 steps = steps[::-1] - counter = Counter({'words', len(ref)}) + counter = Counter({'words': len(ref), 'corr': 0, 'sub': 0, 'ins': 0, + 'del': 0}) counter.update(steps) return dist, steps, counter @@ -190,6 +185,9 @@ def aligned_print(ref, hyp, steps): ref: list of words obtained by splitting reference sentence string hyp: list of words obtained by splitting hypothesis sentence string steps: list of edit steps with elements 'corr', 'sub', 'ins' or 'del'. + + Return: + out_str: aligned reference and hypothesis string with edit steps. 
""" assert isinstance(ref, list) and isinstance(hyp, list) @@ -202,13 +200,13 @@ def aligned_print(ref, hyp, steps): ref_idx = i - steps[:i].count('ins') hyp_idx = i - steps[:i].count('del') if len(ref[ref_idx]) < len(hyp[hyp_idx]): - out_str += ref[ref_idx] + - ' ' * (len(hyp[hyp_idx])-len(ref[ref_idx])) + delim + out_str += ref[ref_idx] + \ + ' ' * (len(hyp[hyp_idx]) - len(ref[ref_idx])) + delim else: out_str += ref[ref_idx] + delim elif steps[i] == 'ins': idx = i - steps[:i].count('del') - out_str += ' ' * len(hyp[idx] + delim + out_str += ' ' * len(hyp[idx]) + delim else: assert steps[i] == 'del' or steps[i] == 'corr' idx = i - steps[:i].count('ins') @@ -221,13 +219,13 @@ def aligned_print(ref, hyp, steps): ref_idx = i - steps[:i].count('ins') hyp_idx = i - steps[:i].count('del') if len(ref[ref_idx]) > len(hyp[hyp_idx]): - out_str += hyp[hyp_idx] + - ' ' * (len(ref[ref_idx])-len(hyp[hyp_idx])) + delim + out_str += hyp[hyp_idx] + \ + ' ' * (len(ref[ref_idx]) - len(hyp[hyp_idx])) + delim else: out_str += hyp[hyp_idx] + delim elif steps[i] == 'del': idx = i - steps[:i].count('ins') - out_str += ' ' * len(ref[idx] + delim + out_str += ' ' * len(ref[idx]) + delim else: assert steps[i] == 'ins' or steps[i] == 'corr' idx = i - steps[:i].count('del') diff --git a/speech_train.py b/speech_train.py new file mode 100644 index 000000000..d361399c6 --- /dev/null +++ b/speech_train.py @@ -0,0 +1,365 @@ +#!/usr/bin/env python3 -u +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. +""" +Train a new model on one or across multiple GPUs. +""" + +import collections +import itertools +import os +import math +import torch + +from fairseq import distributed_utils, options, progress_bar, tasks, utils +from fairseq.data import iterators +from fairseq.trainer import Trainer +from fairseq.meters import AverageMeter, StopwatchMeter + + +def main(args): + if args.max_tokens is None: + args.max_tokens = 6000 + print(args) + + if not torch.cuda.is_available(): + raise NotImplementedError('Training on CPU is not supported') + torch.cuda.set_device(args.device_id) + torch.manual_seed(args.seed) + + # Setup task, e.g., translation, language modeling, etc. + task = tasks.setup_task(args) + + # Load dataset splits + load_dataset_splits(task, ['train', 'valid']) + + # Build model and criterion + model = task.build_model(args) + criterion = task.build_criterion(args) + print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__)) + print('| num. model params: {}'.format(sum(p.numel() for p in model.parameters()))) + + # Make a dummy batch to (i) warm the caching allocator and (ii) as a + # placeholder DistributedDataParallel when there's an uneven number of + # batches per worker. 
+ max_positions = utils.resolve_max_positions( + task.max_positions(), + model.max_positions(), + ) + dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions) + oom_batch = task.dataset('train').get_dummy_batch(1, max_positions) + + # Build trainer + trainer = Trainer(args, task, model, criterion, dummy_batch, oom_batch) + print('| training on {} GPUs'.format(args.distributed_world_size)) + print('| max tokens per GPU = {} and max sentences per GPU = {}'.format( + args.max_tokens, + args.max_sentences, + )) + + # Initialize dataloader + epoch_itr = task.get_batch_iterator( + dataset=task.dataset(args.train_subset), + max_tokens=args.max_tokens, + max_sentences=args.max_sentences, + max_positions=max_positions, + ignore_invalid_inputs=True, + required_batch_size_multiple=8, + seed=args.seed, + num_shards=args.distributed_world_size, + shard_id=args.distributed_rank, + ) + + # Load the latest checkpoint if one is available + if not load_checkpoint(args, trainer, epoch_itr): + trainer.dummy_train_step([dummy_batch]) + + # Train until the learning rate gets too small + max_epoch = args.max_epoch or math.inf + max_update = args.max_update or math.inf + lr = trainer.get_lr() + train_meter = StopwatchMeter() + train_meter.start() + valid_losses = [None] + valid_subsets = args.valid_subset.split(',') + while lr > args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates() < max_update: + # train for one epoch + train(args, trainer, task, epoch_itr) + + if epoch_itr.epoch % args.validate_interval == 0: + valid_losses, valid_wers = validate(args, trainer, task, epoch_itr, valid_subsets) + + # only use first validation loss to update the learning rate + lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0]) + + # save checkpoint + if epoch_itr.epoch % args.save_interval == 0: + save_checkpoint(args, trainer, epoch_itr, valid_wers[0]) + train_meter.stop() + print('| done training in {:.1f} seconds'.format(train_meter.sum)) + + +def train(args, trainer, task, epoch_itr): + """Train the model for one epoch.""" + + # Update parameters every N batches + if epoch_itr.epoch <= len(args.update_freq): + update_freq = args.update_freq[epoch_itr.epoch - 1] + else: + update_freq = args.update_freq[-1] + + # Initialize data iterator + itr = epoch_itr.next_epoch_itr(fix_batches_to_gpus=args.fix_batches_to_gpus) + itr = iterators.GroupedIterator(itr, update_freq) + progress = progress_bar.build_progress_bar( + args, itr, epoch_itr.epoch, no_progress_bar='simple', + ) + + extra_meters = collections.defaultdict(lambda: AverageMeter()) + first_valid = args.valid_subset.split(',')[0] + max_update = args.max_update or math.inf + for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch): + log_output = trainer.train_step(samples) + if log_output is None: + continue + + # log mid-epoch stats + stats = get_training_stats(trainer) + for k, v in log_output.items(): + if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']: + continue # these are already logged above + if 'loss' in k: + extra_meters[k].update(v, log_output['sample_size']) + else: + extra_meters[k].update(v) + stats[k] = extra_meters[k].avg + progress.log(stats) + + # ignore the first mini-batch in words-per-second calculation + if i == 0: + trainer.get_meter('wps').reset() + + num_updates = trainer.get_num_updates() + if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0 and num_updates > 0: + valid_losses, valid_wers = validate(args, trainer, task, epoch_itr, 
[first_valid]) + save_checkpoint(args, trainer, epoch_itr, valid_wers[0]) + + if num_updates >= max_update: + break + + # log end-of-epoch stats + stats = get_training_stats(trainer) + for k, meter in extra_meters.items(): + stats[k] = meter.avg + progress.print(stats) + + # reset training meters + for k in [ + 'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip', + ]: + meter = trainer.get_meter(k) + if meter is not None: + meter.reset() + + +def get_training_stats(trainer): + stats = collections.OrderedDict() + stats['loss'] = '{:.3f}'.format(trainer.get_meter('train_loss').avg) + if trainer.get_meter('train_nll_loss').count > 0: + nll_loss = trainer.get_meter('train_nll_loss').avg + stats['nll_loss'] = '{:.3f}'.format(nll_loss) + else: + nll_loss = trainer.get_meter('train_loss').avg + stats['ppl'] = get_perplexity(nll_loss) + stats['wps'] = round(trainer.get_meter('wps').avg) + stats['ups'] = '{:.1f}'.format(trainer.get_meter('ups').avg) + stats['wpb'] = round(trainer.get_meter('wpb').avg) + stats['bsz'] = round(trainer.get_meter('bsz').avg) + stats['num_updates'] = trainer.get_num_updates() + stats['lr'] = trainer.get_lr() + stats['gnorm'] = '{:.3f}'.format(trainer.get_meter('gnorm').avg) + stats['clip'] = '{:.0%}'.format(trainer.get_meter('clip').avg) + stats['oom'] = trainer.get_meter('oom').avg + if trainer.get_meter('loss_scale') is not None: + stats['loss_scale'] = '{:.3f}'.format(trainer.get_meter('loss_scale').avg) + stats['wall'] = round(trainer.get_meter('wall').elapsed_time) + stats['train_wall'] = round(trainer.get_meter('train_wall').sum) + return stats + + +def validate(args, trainer, task, epoch_itr, subsets): + """Evaluate the model on the validation set(s) and return the losses.""" + valid_losses = [] + valid_wers = [] + for subset in subsets: + # Initialize data iterator + itr = task.get_batch_iterator( + dataset=task.dataset(subset), + max_tokens=args.max_tokens, + max_sentences=args.max_sentences_valid, + max_positions=utils.resolve_max_positions( + task.max_positions(), + trainer.get_model().max_positions(), + ), + ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=8, + seed=args.seed, + num_shards=args.distributed_world_size, + shard_id=args.distributed_rank, + ).next_epoch_itr(shuffle=False) + progress = progress_bar.build_progress_bar( + args, itr, epoch_itr.epoch, + prefix='valid on \'{}\' subset'.format(subset), + no_progress_bar='simple' + ) + + # reset validation loss meters + for k in ['valid_loss', 'valid_nll_loss']: + meter = trainer.get_meter(k) + if meter is not None: + meter.reset() + extra_meters = collections.defaultdict(lambda: AverageMeter()) + + for sample in progress: + log_output = trainer.valid_step(sample) + + for k, v in log_output.items(): + if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', + 'sample_size', 'word_count']: + continue + if k == 'word_error': + extra_meters['valid_wer'].update( + v / log_output['word_count'], log_output['word_count']) + else: + extra_meters[k].update(v) + + # log validation stats + stats = get_valid_stats(trainer) + for k, meter in extra_meters.items(): + stats[k] = meter.avg + progress.print(stats) + + valid_losses.append(stats['valid_loss']) + valid_wers.append(stats['valid_wer']) + return valid_losses, valid_wers + + +def get_valid_stats(trainer): + stats = collections.OrderedDict() + stats['valid_loss'] = trainer.get_meter('valid_loss').avg + if trainer.get_meter('valid_nll_loss').count > 0: + nll_loss = 
trainer.get_meter('valid_nll_loss').avg + stats['valid_nll_loss'] = nll_loss + else: + nll_loss = trainer.get_meter('valid_loss').avg + stats['valid_ppl'] = get_perplexity(nll_loss) + stats['num_updates'] = trainer.get_num_updates() + return stats + + +def get_perplexity(loss): + try: + return '{:.2f}'.format(math.pow(2, loss)) + except OverflowError: + return float('inf') + + +def save_checkpoint(args, trainer, epoch_itr, val_wer): + if args.no_save or not distributed_utils.is_master(args): + return + epoch = epoch_itr.epoch + end_of_epoch = epoch_itr.end_of_epoch() + updates = trainer.get_num_updates() + + checkpoint_conds = collections.OrderedDict() + checkpoint_conds['checkpoint{}.pt'.format(epoch)] = ( + end_of_epoch and not args.no_epoch_checkpoints and + epoch % args.save_interval == 0 + ) + checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = ( + not end_of_epoch and args.save_interval_updates > 0 and + updates % args.save_interval_updates == 0 + ) + checkpoint_conds['checkpoint_best.pt'] = ( + val_wer is not None and + (not hasattr(save_checkpoint, 'best') or val_wer < save_checkpoint.best) + ) + checkpoint_conds['checkpoint_last.pt'] = True # keep this last so that it's a symlink + + prev_best = getattr(save_checkpoint, 'best', val_wer) + if val_wer is not None: + save_checkpoint.best = min(val_wer, prev_best) + extra_state = { + 'train_iterator': epoch_itr.state_dict(), + 'val_wer': val_wer, + } + if hasattr(save_checkpoint, 'best'): + extra_state.update({'best': save_checkpoint.best}) + + checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond] + if len(checkpoints) > 0: + for cp in checkpoints: + trainer.save_checkpoint(cp, extra_state) + + if not end_of_epoch and args.keep_interval_updates > 0: + # remove old checkpoints; checkpoints are sorted in descending order + checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt') + for old_chk in checkpoints[args.keep_interval_updates:]: + os.remove(old_chk) + + +def load_checkpoint(args, trainer, epoch_itr): + """Load a checkpoint and replay dataloader to match.""" + os.makedirs(args.save_dir, exist_ok=True) + checkpoint_path = os.path.join(args.save_dir, args.restore_file) + if os.path.isfile(checkpoint_path): + extra_state = trainer.load_checkpoint(checkpoint_path, args.reset_optimizer, args.reset_lr_scheduler, + eval(args.optimizer_overrides)) + if extra_state is not None: + # replay train iterator to match checkpoint + epoch_itr.load_state_dict(extra_state['train_iterator']) + + print('| loaded checkpoint {} (epoch {} @ {} updates)'.format( + checkpoint_path, epoch_itr.epoch, trainer.get_num_updates())) + + trainer.lr_step(epoch_itr.epoch) + trainer.lr_step_update(trainer.get_num_updates()) + if 'best' in extra_state: + save_checkpoint.best = extra_state['best'] + return True + return False + + +def load_dataset_splits(task, splits): + for split in splits: + if split == 'train': + task.load_dataset(split, combine=True) + else: + for k in itertools.count(): + split_k = split + (str(k) if k > 0 else '') + try: + task.load_dataset(split_k, combine=False) + except FileNotFoundError as e: + if k > 0: + break + raise e + + +if __name__ == '__main__': + parser = options.get_training_parser() + args = options.parse_args_and_arch(parser) + + if args.distributed_port > 0 or args.distributed_init_method is not None: + from distributed_train import main as distributed_main + + distributed_main(args) + elif args.distributed_world_size > 1: + from 
multiprocessing_train import main as multiprocessing_main + + multiprocessing_main(args) + else: + main(args) diff --git a/tests/test_speech_utils.py b/tests/test_speech_utils.py index 2d55cbcb4..a13756f1d 100644 --- a/tests/test_speech_utils.py +++ b/tests/test_speech_utils.py @@ -8,7 +8,7 @@ import unittest import string import numpy as np -import os +from collections import Counter import torch @@ -101,6 +101,112 @@ def test_speech_tokenizer(self): expected_sent = ' '.join(expected_sent) self.assertEqual(reconstructed_sent, expected_sent) + def test_collate_frames(self): + vals = [ + torch.tensor([4.5, 2.3, 1.2]).unsqueeze(-1).expand(-1, 10), + torch.tensor([6.7, 9.8]).unsqueeze(-1).expand(-1, 10), + torch.tensor([7.7, 5.4, 6.2, 8.0]).unsqueeze(-1).expand(-1, 10), + torch.tensor([1.5]).unsqueeze(-1).expand(-1, 10)] + expected_res1 = torch.tensor([ + [4.5, 2.3, 1.2, 0.0], + [6.7, 9.8, 0.0, 0.0], + [7.7, 5.4, 6.2, 8.0], + [1.5, 0.0, 0.0, 0.0]]).unsqueeze(-1).expand(-1, -1, 10) + expected_res2 = torch.tensor([ + [0.0, 4.5, 2.3, 1.2], + [0.0, 0.0, 6.7, 9.8], + [7.7, 5.4, 6.2, 8.0], + [0.0, 0.0, 0.0, 1.5]]).unsqueeze(-1).expand(-1, -1, 10) + + res = utils.collate_frames(vals, pad_value=0.0, left_pad=False) + self.assertTensorEqual(res, expected_res1) + + res = utils.collate_frames(vals, pad_value=0.0, left_pad=True) + self.assertTensorEqual(res, expected_res2) + + def test_sequence_mask(self): + seq_len = torch.tensor([1, 4, 0, 3]).int() + expected_mask = torch.tensor([ + [1, 0, 0, 0], + [1, 1, 1, 1], + [0, 0, 0, 0], + [1, 1, 1, 0]]).byte() + expected_mask2 = torch.tensor([ + [1, 0, 0, 0, 0], + [1, 1, 1, 1, 0], + [0, 0, 0, 0, 0], + [1, 1, 1, 0, 0]]).byte() + + generated_mask = utils.sequence_mask(seq_len) + generated_mask2 = utils.sequence_mask(seq_len, max_len=5) + + self.assertTensorEqual(generated_mask, expected_mask) + self.assertTensorEqual(generated_mask2, expected_mask2) + + def test_convert_padding_direction(self): + t1 = torch.tensor([ + [4.5, 2.3, 1.2, 0.0], + [6.7, 9.8, 0.0, 0.0], + [7.7, 5.4, 6.2, 8.0], + [1.5, 0.0, 0.0, 0.0]]).unsqueeze(-1).expand(-1, -1, 10) + t2 = torch.tensor([ + [0.0, 4.5, 2.3, 1.2], + [0.0, 0.0, 6.7, 9.8], + [7.7, 5.4, 6.2, 8.0], + [0.0, 0.0, 0.0, 1.5]]).unsqueeze(-1).expand(-1, -1, 10) + seq_len = torch.tensor([3, 2, 4, 1]).int() + + t1_to_t2 = utils.convert_padding_direction(t1, seq_len, + right_to_left=True) + self.assertTensorEqual(t1_to_t2, t2) + + t2_to_t1 = utils.convert_padding_direction(t2, seq_len, + left_to_right=True) + self.assertTensorEqual(t2_to_t1, t1) + + def test_edit_distance(self): + ref, hyp = [], [] + dist, steps, counter = utils.edit_distance(ref, hyp) + self.assertEqual(counter, + Counter({'words': 0, 'corr': 0, 'sub': 0, 'ins': 0, 'del': 0})) + self.assertEqual(steps, []) + + ref, hyp = ['a', 'b', 'c'], [] + dist, steps, counter = utils.edit_distance(ref, hyp) + self.assertEqual(counter, + Counter({'words': 3, 'corr': 0, 'sub': 0, 'ins': 0, 'del': 3})) + self.assertEqual(steps, ['del', 'del', 'del']) + + ref, hyp = ['a', 'b', 'c'], ['a', 'b', 'c'] + dist, steps, counter = utils.edit_distance(ref, hyp) + self.assertEqual(counter, + Counter({'words': 3, 'corr': 3, 'sub': 0, 'ins': 0, 'del': 0})) + self.assertEqual(steps, ['corr', 'corr', 'corr']) + + ref, hyp = ['a', 'b', 'c'], ['d', 'b', 'c', 'e', 'f'] + dist, steps, counter = utils.edit_distance(ref, hyp) + self.assertEqual(counter, + Counter({'words': 3, 'corr': 2, 'sub': 1, 'ins': 2, 'del': 0})) + self.assertEqual(steps, ['sub', 'corr', 'corr', 'ins', 'ins']) + + ref, hyp = ['b', 
'c', 'd', 'e', 'f', 'h'], \ + ['d', 'b', 'c', 'e', 'f', 'g'] + dist, steps, counter = utils.edit_distance(ref, hyp) + self.assertEqual(counter, + Counter({'words': 6, 'corr': 4, 'sub': 1, 'ins': 1, 'del': 1})) + self.assertEqual(steps, + ['ins', 'corr', 'corr', 'del', 'corr', 'corr', 'sub']) + + def assertTensorEqual(self, t1, t2): + self.assertEqual(t1.size(), t2.size(), "size mismatch") + if (t1.dtype == torch.short or t1.dtype == torch.int or \ + t1.dtype == torch.long or t1.dtype == torch.uint8) and \ + (t2.dtype == torch.short or t2.dtype == torch.int or \ + t2.dtype == torch.long or t2.dtype == torch.uint8): + self.assertEqual(t1.ne(t2).long().sum(), 0) + else: + self.assertEqual(t1.allclose(t2,rtol=1e-05, atol=1e-08), True) + if __name__ == "__main__": unittest.main() From eadf2a8cf89f534e1b5dffa8b7702361432e161e Mon Sep 17 00:00:00 2001 From: freewym Date: Sat, 12 Jan 2019 20:24:28 -0500 Subject: [PATCH 004/119] code adaptation/changes according to the commits from Dec 24, 2018 to Jan 9, 2019 --- fairseq/data/speech_dataset.py | 30 +++++++------- fairseq/models/speech_lstm.py | 3 +- fairseq/speech_recognizer.py | 11 +++-- fairseq/tasks/speech_recognition.py | 11 ++--- speech_recognition.py | 12 ++++-- speech_train.py | 62 ++++++++++++++++++++++++----- 6 files changed, 90 insertions(+), 39 deletions(-) diff --git a/fairseq/data/speech_dataset.py b/fairseq/data/speech_dataset.py index 2b1132970..5762ddc2d 100644 --- a/fairseq/data/speech_dataset.py +++ b/fairseq/data/speech_dataset.py @@ -90,19 +90,19 @@ class SpeechDataset(FairseqDataset): tgt (torch.utils.data.Dataset, optional): target dataset to wrap tgt_sizes (List[int], optional): target sentence lengths dict (~fairseq.data.Dictionary, optional): target vocabulary - left_pad_source (bool, optional): pad source tensors on the left side. - Default: ``True`` - left_pad_target (bool, optional): pad target tensors on the left side. - Default: ``False`` + left_pad_source (bool, optional): pad source tensors on the left side + (default: True). + left_pad_target (bool, optional): pad target tensors on the left side + (default: False). max_source_positions (int, optional): max number of frames in the - source. Default: ``1024`` + source (default: 1024). max_target_positions (int, optional): max number of tokens in the target - sentence. Default: ``1024`` - shuffle (bool, optional): shuffle dataset elements before batching. - Default: ``True`` + sentence (default: 1024) + shuffle (bool, optional): shuffle dataset elements before batching + (default: True) input_feeding (bool, optional): create a shifted version of the targets - to be passed into the model for input feeding/teacher forcing. 
- Default: ``True`` + to be passed into the model for input feeding/teacher forcing + (default: True) """ def __init__( @@ -232,13 +232,11 @@ def ordered_indices(self): indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')] return indices[np.argsort(self.src_sizes[indices], kind='mergesort')] + @property + def supports_prefetch(self): + return getattr(self.src, 'supports_prefetch', False) + def prefetch(self, indices): """Only prefetch src.""" self.src.prefetch(indices) - @property - def supports_prefetch(self): - return ( - hasattr(self.src, 'supports_prefetch') - and self.src.supports_prefetch - ) diff --git a/fairseq/models/speech_lstm.py b/fairseq/models/speech_lstm.py index 3c2b41fd7..3d890c97c 100644 --- a/fairseq/models/speech_lstm.py +++ b/fairseq/models/speech_lstm.py @@ -267,7 +267,8 @@ def forward(self, src_tokens, src_lengths): if self.bidirectional: def combine_bidir(outs): - return outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous().view(self.num_layers, bsz, -1) + out = outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous() + return out.view(self.num_layers, bsz, -1) final_hiddens = combine_bidir(final_hiddens) final_cells = combine_bidir(final_cells) diff --git a/fairseq/speech_recognizer.py b/fairseq/speech_recognizer.py index 97905d142..3e283d942 100644 --- a/fairseq/speech_recognizer.py +++ b/fairseq/speech_recognizer.py @@ -20,11 +20,14 @@ def generate_batched_itr( cuda=False, timer=None, prefix_size=0, ): """Iterate over a batched dataset and yield individual transcription. + Args: - maxlen_a/b: generate sequences of maximum length ax + b, - where x is the source sentence length. - cuda: use GPU for generation - timer: StopwatchMeter for timing generations. + maxlen_a/b (int, optional): generate sequences of maximum length + ``ax + b``, where ``x`` is the source sentence length. + cuda (bool, optional): use GPU for generation + timer (StopwatchMeter, optional): time generations + prefix_size (int, optional): prefill the generation with the gold + prefix up to this length. """ if maxlen_b is None: maxlen_b = self.maxlen diff --git a/fairseq/tasks/speech_recognition.py b/fairseq/tasks/speech_recognition.py index 092c7595f..f30ead609 100644 --- a/fairseq/tasks/speech_recognition.py +++ b/fairseq/tasks/speech_recognition.py @@ -6,14 +6,17 @@ # can be found in the PATENTS file in the same directory. import itertools -import numpy as np import os import re from fairseq import options, utils from fairseq.data import ( - data_utils, TokenDictionary, SpeechDataset, ConcatDataset, - TokenTextDataset, ScpCachedDataset + ConcatDataset, + data_utils, + ScpCachedDataset, + SpeechDataset, + TokenDictionary, + TokenTextDataset, ) from . 
import FairseqTask, register_task @@ -57,8 +60,6 @@ def add_args(parser): help='path(s) to text file(s) for test') parser.add_argument('--dict', default=None, type=str, help='path to the dictionary') - parser.add_argument('--raw-text', action='store_true', - help='load raw text dataset') parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL', help='pad the source on the left') parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL', diff --git a/speech_recognition.py b/speech_recognition.py index ad7ce6ff9..54d413fa0 100644 --- a/speech_recognition.py +++ b/speech_recognition.py @@ -13,7 +13,7 @@ import torch -from fairseq import wer, data, options, progress_bar, tasks, utils +from fairseq import wer, options, progress_bar, tasks, utils from fairseq.meters import StopwatchMeter from fairseq.speech_recognizer import SpeechRecognizer @@ -40,8 +40,9 @@ def main(args): # Load ensemble print('| loading model(s) from {}'.format(args.path)) - models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task, - model_arg_overrides=eval(args.model_overrides)) + models, _model_args = utils.load_ensemble_for_inference( + args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides), + ) # Optimize ensemble for generation for model in models: @@ -65,9 +66,13 @@ def main(args): required_batch_size_multiple=8, num_shards=args.num_shards, shard_id=args.shard_id, + num_workers=args.num_workers, ).next_epoch_itr(shuffle=False) # Initialize generator + if args.match_source_len: + print('| The option match_source_len is not applicable to ' + 'speech recognition. Ignoring it.') gen_timer = StopwatchMeter() recognizer = SpeechSequenceGenerator( models, dict, beam_size=args.beam, minlen=args.min_len, @@ -78,6 +83,7 @@ def main(args): sampling_temperature=args.sampling_temperature, diverse_beam_groups=args.diverse_beam_groups, diverse_beam_strength=args.diverse_beam_strength, + match_source_len=False, no_repeat_ngram_size=args.no_repeat_ngram_size, ) if use_cuda: diff --git a/speech_train.py b/speech_train.py index d361399c6..61487c7c3 100644 --- a/speech_train.py +++ b/speech_train.py @@ -13,6 +13,8 @@ import itertools import os import math +import random + import torch from fairseq import distributed_utils, options, progress_bar, tasks, utils @@ -26,9 +28,8 @@ def main(args): args.max_tokens = 6000 print(args) - if not torch.cuda.is_available(): - raise NotImplementedError('Training on CPU is not supported') - torch.cuda.set_device(args.device_id) + if torch.cuda.is_available() and not args.cpu: + torch.cuda.set_device(args.device_id) torch.manual_seed(args.seed) # Setup task, e.g., translation, language modeling, etc. 
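
A side note on the `combine_bidir` helper rewritten in the speech_lstm.py hunk above: it folds the per-direction final states of a bidirectional LSTM, shaped `(num_layers * 2, bsz, hidden)`, into `(num_layers, bsz, 2 * hidden)`, with the forward and backward states of each layer concatenated along the feature axis. The following is a minimal standalone sketch (not part of the patch; the tensor sizes are made up for illustration, and it assumes PyTorch's layer-major ordering of `h_n`, i.e. index `2*l` is the forward state of layer `l` and `2*l + 1` its backward state):

    import torch

    num_layers, bsz, hidden = 2, 3, 4
    # final hidden states of a bidirectional LSTM: (num_layers * 2, bsz, hidden)
    outs = torch.randn(num_layers * 2, bsz, hidden)

    def combine_bidir(outs):
        # group the two directions of each layer, then concatenate them
        # along the feature dimension -> (num_layers, bsz, 2 * hidden)
        out = outs.view(num_layers, 2, bsz, -1).transpose(1, 2).contiguous()
        return out.view(num_layers, bsz, -1)

    combined = combine_bidir(outs)
    assert combined.shape == (num_layers, bsz, 2 * hidden)
    assert torch.equal(combined[0, :, :hidden], outs[0])  # layer 0, forward
    assert torch.equal(combined[0, :, hidden:], outs[1])  # layer 0, backward

The refactor in the patch only splits the original one-liner into two statements for readability; the resulting tensor is unchanged.
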
@@ -72,6 +73,7 @@ def main(args): seed=args.seed, num_shards=args.distributed_world_size, shard_id=args.distributed_rank, + num_workers=args.num_workers, ) # Load the latest checkpoint if one is available @@ -210,6 +212,7 @@ def validate(args, trainer, task, epoch_itr, subsets): seed=args.seed, num_shards=args.distributed_world_size, shard_id=args.distributed_rank, + num_workers=args.num_workers, ).next_epoch_itr(shuffle=False) progress = progress_bar.build_progress_bar( args, itr, epoch_itr.epoch, @@ -309,7 +312,15 @@ def save_checkpoint(args, trainer, epoch_itr, val_wer): # remove old checkpoints; checkpoints are sorted in descending order checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt') for old_chk in checkpoints[args.keep_interval_updates:]: - os.remove(old_chk) + if os.path.lexists(old_chk): + os.remove(old_chk) + + if args.keep_last_epochs > 0: + # remove old epoch checkpoints; checkpoints are sorted in descending order + checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint\d+\.pt') + for old_chk in checkpoints[args.keep_last_epochs:]: + if os.path.lexists(old_chk): + os.remove(old_chk) def load_checkpoint(args, trainer, epoch_itr): @@ -349,17 +360,48 @@ def load_dataset_splits(task, splits): raise e +def distributed_main(i, args): + import socket + args.device_id = i + if args.distributed_rank is None: # torch.multiprocessing.spawn + args.distributed_rank = i + args.distributed_rank = distributed_utils.distributed_init(args) + print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank)) + main(args) + + if __name__ == '__main__': parser = options.get_training_parser() args = options.parse_args_and_arch(parser) - if args.distributed_port > 0 or args.distributed_init_method is not None: - from distributed_train import main as distributed_main + if args.distributed_init_method is None: + distributed_utils.infer_init_method(args) - distributed_main(args) + if args.distributed_init_method is not None: + # distributed training + distributed_main(args.device_id, args) elif args.distributed_world_size > 1: - from multiprocessing_train import main as multiprocessing_main - - multiprocessing_main(args) + # fallback for single node with multiple GPUs + port = random.randint(10000, 20000) + args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port) + args.distributed_rank = None # set based on device id + print( + '''| NOTE: you may get better performance with: + + python -m torch.distributed.launch --nproc_per_node {ngpu} train.py {no_c10d}(...) 
+ '''.format( + ngpu=args.distributed_world_size, + no_c10d=( + '--ddp-backend=no_c10d ' if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d' + else '' + ), + ) + ) + torch.multiprocessing.spawn( + fn=distributed_main, + args=(args, ), + nprocs=args.distributed_world_size, + ) else: + # single GPU training main(args) From 6148a8a7436bd6e574ebc39f4bc69a48d98ba2bc Mon Sep 17 00:00:00 2001 From: freewym Date: Mon, 14 Jan 2019 17:48:07 -0500 Subject: [PATCH 005/119] wsj recipe and other fixes --- examples/asr_wsj/path.sh | 7 + examples/asr_wsj/run.sh | 94 ++++++++++ fairseq/criterions/cross_entropy_with_wer.py | 89 +++++++--- fairseq/data/scp_dataset.py | 14 +- fairseq/data/speech_dataset.py | 16 +- fairseq/data/token_dictionary.py | 21 ++- fairseq/models/speech_lstm.py | 177 +++++++++++-------- fairseq/modules/speech_attention.py | 20 +-- fairseq/tasks/speech_recognition.py | 64 ++++--- fairseq/wer.py | 77 +++++--- speech_recognition.py | 59 ++++--- speech_tools/__init__.py | 0 speech_tools/parse_options.sh | 97 ++++++++++ speech_tools/text2token.py | 48 +++++ speech_tools/utils.py | 41 ++++- speech_train.py | 26 ++- tests/test_speech_dataset.py | 2 + tests/test_speech_utils.py | 2 + 18 files changed, 651 insertions(+), 203 deletions(-) create mode 100644 examples/asr_wsj/path.sh create mode 100755 examples/asr_wsj/run.sh create mode 100644 speech_tools/__init__.py create mode 100755 speech_tools/parse_options.sh create mode 100755 speech_tools/text2token.py diff --git a/examples/asr_wsj/path.sh b/examples/asr_wsj/path.sh new file mode 100644 index 000000000..731d834f6 --- /dev/null +++ b/examples/asr_wsj/path.sh @@ -0,0 +1,7 @@ +export PATH=~/anaconda3/bin:$PATH + +MAIN_ROOT=$PWD/../.. + +export PATH=$MAIN_ROOT:$PATH +export LC_ALL=C + diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh new file mode 100755 index 000000000..1af8af9bf --- /dev/null +++ b/examples/asr_wsj/run.sh @@ -0,0 +1,94 @@ +#!/bin/bash + +# Copyright (c) 2019-present, Yiming Wang +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +set -e -o pipefail + +stage=0 +free_gpu= +data_dir=data-bin/wsj +exp_dir=exp/wsj/lstm +train_set=train_si284 +valid_set=test_dev93 +test_set=test_eval92 +checkpoint=checkpoint_best.pt + + +if [ -f ./path.sh ]; then + . ./path.sh +else + . ./examples/asr_wsj/path.sh +fi +if [ -f ../../speech_tools/parse_options.sh ]; then + . ../../speech_tools/parse_options.sh +else + . ./speech_tools/parse_options.sh +fi + +dict=$data_dir/lang/${train_set}_units.txt +nlsyms=$data_dir/lang/non_lang_syms.txt +train_text=$data_dir/$train_set/text +train_token_text=$data_dir/$train_set/token_text +valid_text=$data_dir/$valid_set/text +valid_token_text=$data_dir/$valid_set/token_text +test_text=$data_dir/$test_set/text +test_token_text=$data_dir/$test_set/token_text +if [ ${stage} -le 1 ]; then + echo "Stage 1: Dictionary Preparation and Text Tokenization" + mkdir -p $data_dir/lang + + echo "Making a non-linguistic symbol list..." + cut -f 2- $train_text | tr " " "\n" | sort | uniq | grep "<" > $nlsyms + cat $nlsyms + + echo "Making a dictionary and tokenizing text for training set..." 
+ python3 speech_tools/text2token.py --skip-ncols 1 --space "" --non-lang-syms $nlsyms \ + $train_text > $train_token_text + cut -f 2- -d" " $train_token_text | tr " " "\n" | grep -v -e '^\s*$' | sort | \ + uniq -c | awk '{print $2,$1}' > $dict + wc -l $dict + + echo "Tokenizing text for validation and test set..." + python3 speech_tools/text2token.py --skip-ncols 1 --space "" --non-lang-syms $nlsyms \ + $valid_text > $valid_token_text + python3 speech_tools/text2token.py --skip-ncols 1 --space "" --non-lang-syms $nlsyms \ + $test_text > $test_token_text +fi + +train_feat=$data_dir/dump/$train_set/deltafalse/feats.scp +valid_feat=$data_dir/dump/$valid_set/deltafalse/feats.scp +if [ ${stage} -le 2 ]; then + echo "Stage 2: Model Training" + mkdir -p $exp_dir/logs + [ -z "$free_gpu" ] && free_gpu=$(free-gpu) + CUDA_VISIBLE_DEVICES=$free_gpu python3 -u speech_train.py \ + --log-interval 1000 --log-format "simple" --seed 1 \ + --num-workers 0 --max-tokens 45000 --max-sentences 32 --max-sentences-valid 64 \ + --distributed-world-size 1 --distributed-rank 0 --distributed-port -1 \ + --max-epoch 20 --optimizer "adam" --lr 0.1 --weight-decay 0.0 \ + --lr-scheduler "reduce_lr_on_plateau" --lr-shrink 0.1 \ + --save-dir $exp_dir --save-interval-updates 100 --keep-interval-updates 10 \ + --keep-last-epochs 5 --validate-interval 100 \ + --arch "speech_conv_lstm_wsj" --criterion "cross_entropy_with_wer" \ + --train-feat-files $train_feat --train-text-files $train_token_text \ + --valid-feat-files $valid_feat --valid-text-files $valid_token_text \ + --dict $dict --max-source-positions 9999 --max-target-positions 999 2>&1 | tee $exp_dir/logs/train.log +fi +exit 0 + +test_feat=$data_dir/dump/$test_set/deltafalse/feats.scp +if [ ${stage} -le 2 ]; then + echo "Stage 3: Decoding" + [ -z "$free_gpu" ] && free_gpu=$(free-gpu) + CUDA_VISIBLE_DEVICES=$free_gpu python3 -u speech_recognition.py \ + --max-tokens 45000 --max-sentences 32 --num-shards 1 --shard-id 0 \ + --test-feat-files $test_feat --test-text-files $test_token_text \ + --dict $dict --max-source-positions 9999 --max-target-positions 999 \ + --path $exp_dir/$checkpoint --beam 10 --max-len-a 0.5 --max-len-b 0 \ + --lenpen 1.0 --print-alignment 2>&1 | tee $exp_dir/logs/decode.log +fi diff --git a/fairseq/criterions/cross_entropy_with_wer.py b/fairseq/criterions/cross_entropy_with_wer.py index 79626fa79..4aac9efea 100644 --- a/fairseq/criterions/cross_entropy_with_wer.py +++ b/fairseq/criterions/cross_entropy_with_wer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2017-present, Facebook, Inc. +# Copyright (c) 2018-present, Yiming Wang # All rights reserved. # # This source code is licensed under the license found in the LICENSE file in @@ -6,9 +6,11 @@ # can be found in the PATENTS file in the same directory. import math +import numpy as np import torch.nn.functional as F from fairseq import utils, wer +from fairseq.data import data_utils from . 
import FairseqCriterion, register_criterion from .cross_entropy import CrossEntropyCriterion @@ -19,9 +21,21 @@ class CrossEntropyWithWERCriterion(CrossEntropyCriterion): def __init__(self, args, task): super().__init__(args, task) - dict = self.task.dict if hasattr(self.task, 'dict') \ - else self.task.tgt_dict + + dict = task.dict if hasattr(task, 'dict') else getattr(task, 'tgt_dict') self.scorer = wer.Scorer(dict) + self.num_calls = 0 #getattr(task, 'iterations_in_epoch', 0) + self.train_tgt_dataset = task.dataset(args.train_subset).tgt + self.valid_tgt_dataset = None + + @staticmethod + def add_args(parser): + """Add criterion-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--sample-results-interval', type=int, metavar='N', default=500, + help='print sample results interval every this ' + 'number of forward times') + # fmt: on def forward(self, model, sample, reduce=True): """Compute the loss for the given sample. @@ -33,19 +47,49 @@ def forward(self, model, sample, reduce=True): """ net_output = model(**sample['net_input']) lprobs = model.get_normalized_probs(net_output, log_probs=True) - # wer code starts - if not model.training: - pred = lprobs.argmax(-1).int().cpu() # bsz x len + # wer stats code starts + if not model.training or \ + self.num_calls % self.args.sample_results_interval == 0: + pred = lprobs.argmax(-1).long().cpu() # bsz x len + target = sample['target'].long().cpu() # bsz x len assert pred.size() == sample['net_input']['prev_output_tokens'].size() - assert pred.size() == sample['target'].size() - dict = self.task.dict if hasattr(self.task, 'dict') \ - else self.task.tgt_dict - self.scorer.reset() - ref_str_list = dict.string(sample['target'].int().cpu()).split('\n') - pred_str_list = dict.string(pred).split('\n') - for ref_str, pred_str in zip(ref_str_list, pred_str_list): - scorer.add(ref_str, pred_str) - # wer code ends + assert pred.size() == target.size() + + def lengths_strip_padding(idx_array, padding_idx): + # assume sequences are right-padded, so work out the length by + # looking for the first occurence of padding_idx + assert idx_array.ndim == 1 or idx_array.ndim == 2 + if idx_array.ndim == 1: + try: + return idx_array.tolist().index(padding_idx) + except ValueError: + return len(idx_array) + return [lengths_strip_padding(row, padding_idx) for row in idx_array] + + target_lengths = lengths_strip_padding(target.numpy(), + self.padding_idx) + dict = self.scorer.dict + if not model.training: # validation step, compute WER stats with scorer + self.scorer.reset() + for i, length in enumerate(target_lengths): + utt_id = sample['utt_id'][i] + id = sample['id'].data[i] + #ref_str = dict.string(target.data[i]) + ref_str = self.valid_tgt_dataset.get_original_tokens(id) + pred_str = dict.string(pred.data[i][:length]) + self.scorer.add_evaluation(utt_id, ref_str, pred_str) + else: # print a randomly sampled result every sample_results_interval batch + with data_utils.numpy_seed(self.num_calls): + i = np.random.randint(0, len(sample['id'])) + id = sample['id'].data[i] + #ref_str_one = dict.string(target.data[i]) + ref_str_one = self.train_tgt_dataset.get_original_tokens(id) + pred_str_one = dict.string(pred.data[i][:target_lengths[i]]) + print('| ' + sample['utt_id'][i]) + print('| sample REF: ' + ref_str_one) + print('| sample PRD: ' + pred_str_one) + self.num_calls += 1 + # wer stats code ends lprobs = lprobs.view(-1, lprobs.size(-1)) target = model.get_targets(sample, net_output).view(-1) loss = F.nll_loss(lprobs, target, size_average=False, 
ignore_index=self.padding_idx, @@ -57,20 +101,21 @@ def forward(self, model, sample, reduce=True): 'nsentences': sample['target'].size(0), 'sample_size': sample_size, } - if not model.training: - logging_output['word_error'] = scorer.acc_word_error() - logging_output['word_count'] = scorer.acc_word_count() + if not model.training: # do not compute word error in training mode + logging_output['word_error'] = self.scorer.tot_word_error() + logging_output['word_count'] = self.scorer.tot_word_count() return loss, sample_size, logging_output @staticmethod def aggregate_logging_outputs(logging_outputs): """Aggregate logging outputs from data parallel training.""" - agg_output = super().aggregate_logging_outputs(logging_outputs) + agg_output = CrossEntropyCriterion.aggregate_logging_outputs(logging_outputs) word_error = sum(log.get('word_error', 0) for log in logging_outputs) word_count = sum(log.get('word_count', 0) for log in logging_outputs) - if word_count > 0: + if word_count > 0: # model.training == False agg_output['word_error'] = word_error agg_output['word_count'] = word_count - else: - print('Not aggregating WER in training mode.') return agg_output + + def set_valid_tgt_dataset(self, dataset): + self.valid_tgt_dataset = dataset diff --git a/fairseq/data/scp_dataset.py b/fairseq/data/scp_dataset.py index 761a59b96..d4aadddcd 100644 --- a/fairseq/data/scp_dataset.py +++ b/fairseq/data/scp_dataset.py @@ -101,16 +101,17 @@ def prefetch(self, indices): assert isinstance(indices, (list, np.ndarray)) assert self.size >= len(indices) self.ordered_indices = indices.copy() + self.start_pos_for_next_cache = 0 def __getitem__(self, i): self.check_index(i) if i not in self.cache_index: - assert self.start_search_for_next_pos_start < \ + assert self.start_pos_for_next_cache < \ len(self.ordered_indices), \ - 'Search position starting beyond the end of ordered_indices.' + 'Position for next cache starting beyond the end of ordered_indices.' try: pos_start = self.ordered_indices.index(i, - self.start_search_for_next_pos_start) + self.start_pos_for_next_cache) except ValueError: print('index {} not found in self.ordered_indices. 
Set ' 'self.ordered_prefetch to False, and/or call self.prefetch() ' @@ -118,7 +119,7 @@ def __getitem__(self, i): raise pos_end = min(pos_start + self.cache_size, len(self.ordered_indices)) - self.start_search_for_next_pos_start = pos_end \ + self.start_pos_for_next_cache = pos_end \ if self.ordered_prefetch else 0 total_size = 0 for idx in self.ordered_indices[pos_start : pos_end]: @@ -221,9 +222,10 @@ def get_original_tokens(self, i): self.check_index(i) return self.tokens_list[i] - def get_original_text(self, i): + def get_original_text(self, i, dictionary): self.check_index(i) - return Tokenizer.tokens_to_sentence(self.tokens_list[i]) + return Tokenizer.tokens_to_sentence(self.tokens_list[i], dictionary, + use_unk_sym=False) def __len__(self): return self.size diff --git a/fairseq/data/speech_dataset.py b/fairseq/data/speech_dataset.py index 5762ddc2d..17deae6aa 100644 --- a/fairseq/data/speech_dataset.py +++ b/fairseq/data/speech_dataset.py @@ -61,7 +61,7 @@ def merge(key, left_pad, move_eos_to_beginning=False): ) prev_output_tokens = prev_output_tokens.index_select(0, sort_order) else: - ntokens = sum(len(s['source']) for s in samples) + ntokens = sum(s['source'].size(0) for s in samples) batch = { 'id': id, @@ -108,7 +108,7 @@ class SpeechDataset(FairseqDataset): def __init__( self, src, src_sizes, tgt=None, tgt_sizes=None, dict=None, - left_pad_source=True, left_pad_target=False, + left_pad_source=False, left_pad_target=False, max_source_positions=1024, max_target_positions=1024, shuffle=True, input_feeding=True, ): @@ -136,6 +136,7 @@ def _match_src_tgt(self): src_indices = [i for i, id in enumerate(self.src.utt_ids) \ if id in tgt_utt_ids_set] self.src.filter_and_reorder(src_indices) + self.src_sizes = np.array(self.src.sizes) try: tgt_indices = list(map(self.tgt.utt_ids.index, self.src.utt_ids)) except ValueError: @@ -143,6 +144,7 @@ def _match_src_tgt(self): happen. Something must be wrong.') raise self.tgt.filter_and_reorder(tgt_indices) + self.tgt_sizes = np.array(self.tgt.sizes) assert self.src.utt_ids == self.tgt.utt_ids def __getitem__(self, index): @@ -183,7 +185,7 @@ def collater(self, samples): is ``False``. Padding will appear on the left if *left_pad_target* is ``True``. - - `target` (IntTensor): a padded 2D Tensor of tokens in the + - `target` (LongTensor): a padded 2D Tensor of tokens in the target sentence of shape `(bsz, tgt_len)`. Padding will appear on the left if *left_pad_target* is ``True``. """ @@ -193,14 +195,14 @@ def collater(self, samples): input_feeding=self.input_feeding, ) - def get_dummy_batch(self, num_tokens, max_positions, src_len=128, tgt_len=128): + def get_dummy_batch(self, num_tokens, max_positions, max_sentences=16, src_len=300, tgt_len=30): """Return a dummy batch with a given number of tokens.""" src_len, tgt_len = utils.resolve_max_positions( (src_len, tgt_len), max_positions, (self.max_source_positions, self.max_target_positions), ) - bsz = max(num_tokens // tgt_len, 1) + bsz = max(min(num_tokens // src_len, max_sentences), 1) return self.collater([ { 'id': i, @@ -212,9 +214,9 @@ def get_dummy_batch(self, num_tokens, max_positions, src_len=128, tgt_len=128): ]) def num_tokens(self, index): - """Return the number of tokens in a sample. This value is used to + """Return the number of frames in a sample. This value is used to enforce ``--max-tokens`` during batching.""" - return self.tgt_sizes[index] if self.tgt_sizes is not None else 0 + return self.src_sizes[index] def size(self, index): """Return an example's size as a float or tuple. 
This value is used when diff --git a/fairseq/data/token_dictionary.py b/fairseq/data/token_dictionary.py index 026ceb0fb..514b406bd 100644 --- a/fairseq/data/token_dictionary.py +++ b/fairseq/data/token_dictionary.py @@ -21,7 +21,6 @@ def __init__(self, pad='', eos='', unk='', space=''): self.pad_index = self.add_symbol(pad) self.eos_index = self.add_symbol(eos) self.unk_index = self.add_symbol(unk) - self.space_index = self.add_symbol(space) self.nspecial = len(self.symbols) def string(self, tensor, bpe_symbol=None, escape_unk=False): @@ -50,8 +49,24 @@ def space(self): """Helper to get index of space symbol""" return self.space_index + @classmethod + def load(cls, f, ignore_utf_errors=False): + """Loads the dictionary from a text file with the format: + + ``` + + + ... + ``` + and identifies the space symbol if it exists, by obtaining its index + (space_index=-1 if no space symbol) + """ + d = super().load(f, ignore_utf_errors) + d.space_index = d.indices.get(d.space_word, -1) + return d + def dummy_sentence(self, length): - # sample starting from space - t = torch.Tensor(length).uniform_(self.nspecial - 1, len(self)).int() + # sample excluding special symbols + t = torch.Tensor(length).uniform_(self.nspecial, len(self)).long() t[-1] = self.eos() return t diff --git a/fairseq/models/speech_lstm.py b/fairseq/models/speech_lstm.py index 3d890c97c..74d535b48 100644 --- a/fairseq/models/speech_lstm.py +++ b/fairseq/models/speech_lstm.py @@ -29,16 +29,15 @@ def __init__(self, encoder, decoder): @staticmethod def add_args(parser): """Add model-specific arguments to the parser.""" + # fmt: off parser.add_argument('--dropout', type=float, metavar='D', help='dropout probability') - parser.add_argument('feat-dim', type=int, metavar='N', - help='input feature dimension') parser.add_argument('--encoder-conv-channels', type=str, metavar='STR', help='list of encoder convolution\'s out channels') - parser.add_argument('--encoder-conv-kernel-size', type=str, metavar='STR', - help='list of encoder convolution\'s kernel size') - parser.add_argument('--encoder-conv-stride', type=str, metavar='STR', - help='list of encoder convolution\'s stride') + parser.add_argument('--encoder-conv-kernel-sizes', type=str, metavar='STR', + help='list of encoder convolution\'s kernel sizes') + parser.add_argument('--encoder-conv-strides', type=str, metavar='STR', + help='list of encoder convolution\'s strides') parser.add_argument('--encoder-rnn-hidden-size', type=int, metavar='N', help='encoder rnn\'s hidden size') parser.add_argument('--encoder-rnn-layers', type=int, metavar='N', @@ -76,6 +75,7 @@ def add_args(parser): parser.add_argument('--share-decoder-input-output-embed', default=False, action='store_true', help='share decoder input and output embeddings') + # fmt: on @classmethod def build_model(cls, args, task): @@ -107,19 +107,44 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): '--decoder-embed-dim to match --decoder-out-embed-dim' ) - out_channels = options.eval_str_list(args.encoder_conv_channels, + def eval_str_nested_list_or_tuple(x, type=int): + if x is None: + return None + if isinstance(x, str): + x = eval(x) + if isinstance(x, list): + return list( + map(lambda s: eval_str_nested_list_or_tuple(s, type), x)) + elif isinstance(x, tuple): + return tuple( + map(lambda s: eval_str_nested_list_or_tuple(s, type), x)) + else: + try: + return type(x) + except: + raise ValueError + + out_channels = eval_str_nested_list_or_tuple(args.encoder_conv_channels, type=int) - kernel_size = 
options.eval_str_list(args.encoder_conv_kernel_size, + kernel_sizes = eval_str_nested_list_or_tuple( + args.encoder_conv_kernel_sizes, type=int) + strides = eval_str_nested_list_or_tuple(args.encoder_conv_strides, type=int) - stride = options.eval_str_list(args.encoder_conv_stride, type=int) - in_channel = 1 # hard-coded for now - conv_layers = ConvBNReLU(out_channels, kernel_size, stride, - in_channel=in_channel) if not out_channels is None else None + in_channels = 1 # hard-coded for now + conv_layers = ConvBNReLU(out_channels, kernel_sizes, strides, + in_channels=in_channels) if not out_channels is None else None - rnn_encoder_input_size = args.feat_dim // in_channel + assert task.feat_dim % in_channels == 0 + rnn_encoder_input_size = task.feat_dim // in_channels if conv_layers is not None: - for s in stride: - rnn_encoder_input_size = (rnn_input_size + s[1] - 1) // s[1] + for stride in strides: + if isinstance(stride, (list, tuple)): + assert len(stride) > 0 + s = stride[1] if len(stride) > 1 else stride[0] + else: + assert isinstance(stride, int) + s = stride + rnn_encoder_input_size = (rnn_encoder_input_size + s - 1) // s rnn_encoder_input_size *= out_channels[-1] encoder = SpeechLSTMEncoder( @@ -154,30 +179,42 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): class ConvBNReLU(nn.Module): """Sequence of convolution-BatchNorm-ReLU layers.""" - def __init__(self, out_channels, kernel_size, stride, in_channel=1): + def __init__(self, out_channels, kernel_sizes, strides, in_channels=1): super().__init__() self.out_channels = out_channels - self.kernel_size = kernel_size - self.stride = stride - self.in_channel = in_channel + self.kernel_sizes = kernel_sizes + self.strides = strides + self.in_channels = in_channels - self.num_layers = len(out_channels) - assert num_layers == len(kernel_size) and num_layers == len(stride) + num_layers = len(out_channels) + assert num_layers == len(kernel_sizes) and num_layers == len(strides) self.convolutions = nn.ModuleList() self.batchnorms = nn.ModuleList() - for i in range(self.num_layers): + for i in range(num_layers): self.convolutions.append( Convolution2d( - self.in_channel if i == 0 else self.out_channels[i-1], + self.in_channels if i == 0 else self.out_channels[i-1], self.out_channels[i], - self.kernel_size[i], self.stride[i])) + self.kernel_sizes[i], self.strides[i])) self.batchnorms.append(nn.BatchNorm2d(out_channels[i])) + def output_lengths(self, in_lengths): + out_lengths = in_lengths + for stride in self.strides: + if isinstance(stride, (list, tuple)): + assert len(stride) > 0 + s = stride[0] + else: + assert isinstance(stride, int) + s = stride + out_lengths = (out_lengths + s - 1) // s + return out_lengths + def forward(self, src, src_lengths): # B X T X C -> B X (input channel num) x T X (C / input channel num) - x = src.view(src.size(0), src.size(1), self.in_channel, - src.size(2) // self.in_channel).transpose(1, 2) + x = src.view(src.size(0), src.size(1), self.in_channels, + src.size(2) // self.in_channels).transpose(1, 2) for conv, bn in zip(self.convolutions, self.batchnorms): x = F.relu(bn(conv(x))) # B X (output channel num) x T X C' -> B X T X (output channel num) X C' @@ -185,9 +222,7 @@ def forward(self, src, src_lengths): # B X T X (output channel num) X C' -> B X T X C x = x.contiguous().view(x.size(0), x.size(1), x.size(2) * x.size(3)) - x_lengths = src_lengths - for i in range(self.num_layers): - x_lengths = (x_lengths + self.stride[0] - 1) // self.stride[0] + x_lengths = 
self.output_lengths(src_lengths) padding_mask = 1 - speech_utils.sequence_mask(x_lengths, x.size(1)) if padding_mask.any(): x = x.masked_fill(padding_mask.unsqueeze(-1), 0.0) @@ -198,7 +233,7 @@ def forward(self, src, src_lengths): class SpeechLSTMEncoder(FairseqEncoder): """LSTM encoder.""" def __init__( - self, conv_layers_before=None, input_size=40, hidden_size=512, + self, conv_layers_before=None, input_size=80, hidden_size=512, num_layers=1, dropout_in=0.1, dropout_out=0.1, bidirectional=False, left_pad=False, pretrained_embed=None, padding_value=0., ): @@ -224,6 +259,9 @@ def __init__( if bidirectional: self.output_units *= 2 + def output_lengths(self, in_lengths): + return in_lengths if self.conv_layers_before is None \ + else self.conv_layers_before.output_lengths(in_lengths) def forward(self, src_tokens, src_lengths): if self.left_pad: @@ -238,7 +276,8 @@ def forward(self, src_tokens, src_lengths): x, src_lengths, padding_mask = self.conv_layers_before(src_tokens, src_lengths) else: - x = src_tokens + x, padding_mask = src_tokens, \ + 1 - speech_utils.sequence_mask(src_lengths, src_tokens.size(1)) bsz, seqlen = x.size(0), x.size(1) @@ -328,16 +367,16 @@ def __init__( ) for layer in range(num_layers) ]) - if attn.type == 'bahdanau': + if attn_type == 'bahdanau': self.attention = speech_attention.BahdanauAttention(hidden_size, encoder_output_units, attn_dim) - elif attn.type == 'luong': + elif attn_type == 'luong': self.attention = speech_attention.LuongAttention(hidden_size, encoder_output_units) else: raise ValueError('unrecognized attention type.') - if hidden_size != out_embed_dim: - self.additional_fc = Linear(hidden_size, out_embed_dim) + if hidden_size + encoder_output_units != out_embed_dim: + self.additional_fc = Linear(hidden_size + encoder_output_units, out_embed_dim) if adaptive_softmax_cutoff is not None: # setting adaptive_softmax dropout to dropout_out for now but can be redefined self.adaptive_softmax = AdaptiveSoftmax(num_embeddings, embed_dim, adaptive_softmax_cutoff, @@ -458,18 +497,20 @@ def make_generation_fast_(self, need_attn=False, **kwargs): def Convolution2d(in_channels, out_channels, kernel_size, stride): - if len(kernel_size) != 2: - if len(kernel_size) == 1: + if isinstance(kernel_size, (list, tuple)): + if len(kernel_size) != 2: + assert len(kernel_size) == 1 kernel_size = (kernel_size[0], kernel_size[0]) - else: - assert isinstance(kernel_size, int) - kernel_size = (kernel_size, kernel_size) - if len(stride) != 2: - if len(stride) == 1: + else: + assert isinstance(kernel_size, int) + kernel_size = (kernel_size, kernel_size) + if isinstance(stride, (list, tuple)): + if len(stride) != 2: + assert len(stride) == 1 stride = (stride[0], stride[0]) - else: - assert isinstance(stride, int) - stride = (stride, stride) + else: + assert isinstance(stride, int) + stride = (stride, stride) assert kernel_size[0] % 2 == 1 and kernel_size[1] % 2 == 1 padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2) m = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, \ @@ -480,27 +521,32 @@ def Convolution2d(in_channels, out_channels, kernel_size, stride): @register_model_architecture('speech_lstm', 'speech_lstm') def base_architecture(args): args.dropout = getattr(args, 'dropout', 0.1) - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512) - args.encoder_embed_path = getattr(args, 'encoder_embed_path', None) - args.encoder_hidden_size = getattr(args, 'encoder_hidden_size', args.encoder_embed_dim) - args.encoder_layers = getattr(args, 
'encoder_layers', 1) - args.encoder_bidirectional = getattr(args, 'encoder_bidirectional', False) - args.encoder_dropout_in = getattr(args, 'encoder_dropout_in', args.dropout) - args.encoder_dropout_out = getattr(args, 'encoder_dropout_out', args.dropout) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512) + args.encoder_conv_channels = getattr(args, 'encoder_conv_channels', + '[64, 64, 128, 128]') + args.encoder_conv_kernel_sizes = getattr(args, 'encoder_conv_kernel_sizes', + '[(3, 3), (3, 3), (3, 3), (3, 3)]') + args.encoder_conv_strides = getattr(args, 'encoder_conv_strides', + '[(1, 1), (2, 2), (1, 1), (2, 2)]') + args.encoder_rnn_hidden_size = getattr(args, 'encoder_rnn_hidden_size', 320) + args.encoder_rnn_layers = getattr(args, 'encoder_rnn_layers', 3) + args.encoder_rnn_bidirectional = getattr(args, 'encoder_rnn_bidirectional', True) + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 320) args.decoder_embed_path = getattr(args, 'decoder_embed_path', None) args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', args.decoder_embed_dim) - args.decoder_layers = getattr(args, 'decoder_layers', 1) - args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 512) - args.decoder_attention = getattr(args, 'decoder_attention', '1') + args.decoder_layers = getattr(args, 'decoder_layers', 3) + args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 960) + args.attention_type = getattr(args, 'attention_type', 'bahdanau') + args.attention_dim = getattr(args, 'attention_dim', 320) + args.encoder_rnn_dropout_in = getattr(args, 'encoder_rnn_dropout_in', args.dropout) + args.encoder_rnn_dropout_out = getattr(args, 'encoder_rnn_dropout_out', args.dropout) args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', args.dropout) args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout) args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False) - args.share_all_embeddings = getattr(args, 'share_all_embeddings', False) args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,50000,200000') -@register_model_architecture('speech_lstm', 'speech_lstm_wiseman_iwslt_de_en') -def lstm_wiseman_iwslt_de_en(args): +@register_model_architecture('speech_lstm', 'speech_conv_lstm_wsj') +def conv_lstm_wsj(args): + ''' args.dropout = getattr(args, 'dropout', 0.1) args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256) args.encoder_dropout_in = getattr(args, 'encoder_dropout_in', 0) @@ -509,16 +555,5 @@ def lstm_wiseman_iwslt_de_en(args): args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256) args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', 0) args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout) - base_architecture(args) - - -@register_model_architecture('speech_lstm', 'speech_lstm_luong_wmt_en_de') -def lstm_luong_wmt_en_de(args): - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1000) - args.encoder_layers = getattr(args, 'encoder_layers', 4) - args.encoder_dropout_out = getattr(args, 'encoder_dropout_out', 0) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1000) - args.decoder_layers = getattr(args, 'decoder_layers', 4) - args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 1000) - args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', 0) + ''' base_architecture(args) diff --git a/fairseq/modules/speech_attention.py b/fairseq/modules/speech_attention.py index 
61d823cf7..f0e083a4c 100644 --- a/fairseq/modules/speech_attention.py +++ b/fairseq/modules/speech_attention.py @@ -38,8 +38,8 @@ class BahdanauAttention(BaseAttention): def __init__(self, query_dim, value_dim, embed_dim, normalize=True): super().__init__(query_dim, value_dim, embed_dim) - self.query_proj = nn.Linear(self.query_dim, self.embed_dim, bias=False) - self.value_proj = nn.Linear(self.value_dim, self.embed_dim, bias=False) + self.query_proj = nn.Linear(self.query_dim, embed_dim, bias=False) + self.value_proj = nn.Linear(self.value_dim, embed_dim, bias=False) self.v = Parameter(torch.Tensor(embed_dim)) self.normalize = normalize if self.normalize: @@ -54,7 +54,7 @@ def reset_parameters(self): nn.init.uniform_(self.v, -0.1, 0.1) if self.normalize: nn.init.constant_(self.b, 0.) - nn.init.constant_(self.g, math.sqrt(1. / embed_dim)) + nn.init.constant_(self.g, math.sqrt(1. / self.embed_dim)) def forward(self, query, value, key_padding_mask=None, state=None): # projected_query: 1 x bsz x embed_dim @@ -63,14 +63,14 @@ def forward(self, query, value, key_padding_mask=None, state=None): if self.normalize: # normed_v = g * v / ||v|| normed_v = self.g * self.v / torch.norm(self.v) - attn_scores = (normed_v * nn.tanh(projected_query + key + \ + attn_scores = (normed_v * torch.tanh(projected_query + key + \ self.b)).sum(dim=2) # len x bsz else: - attn_scores = v * nn.tanh(projected_query + key).sum(dim=2) + attn_scores = v * torch.tanh(projected_query + key).sum(dim=2) - if encoder_padding_mask is not None: + if key_padding_mask is not None: attn_scores = attn_scores.float().masked_fill_( - encoder_padding_mask, float('-inf'), + key_padding_mask, float('-inf'), ).type_as(attn_scores) # FP16 support: cast to float and back attn_scores = F.softmax(attn_scores, dim=0) # len x bsz @@ -100,16 +100,16 @@ def reset_parameters(self): nn.init.constant_(self.g, 1.) 
def forward(self, query, value, key_padding_mask=None, state=None): - query = self.query_proj(query).unsqueeze(1) # bsz x 1 x query_dim + query = query.unsqueeze(1) # bsz x 1 x query_dim key = self.value_proj(value).transpose(0, 1) # bsz x len x query_dim attn_scores = torch.bmm(query, key.transpose(1, 2)).squeeze(1) attn_scores = attn_scores.transpose(0, 1) # len x bsz if self.scale: attn_scores = self.g * attn_scores - if encoder_padding_mask is not None: + if key_padding_mask is not None: attn_scores = attn_scores.float().masked_fill_( - encoder_padding_mask, float('-inf'), + key_padding_mask, float('-inf'), ).type_as(attn_scores) # FP16 support: cast to float and back attn_scores = F.softmax(attn_scores, dim=0) # len x bsz diff --git a/fairseq/tasks/speech_recognition.py b/fairseq/tasks/speech_recognition.py index f30ead609..23e21fae2 100644 --- a/fairseq/tasks/speech_recognition.py +++ b/fairseq/tasks/speech_recognition.py @@ -46,26 +46,29 @@ class SpeechRecognitionTask(FairseqTask): @staticmethod def add_args(parser): """Add task-specific arguments to the parser.""" - parser.add_argument('--train-scp-files', nargs='+', - help='path(s) to scp file(s) for training') + parser.add_argument('--train-feat-files', nargs='+', + help='path(s) to scp feature file(s) for training') parser.add_argument('--train-text-files', nargs='+', - help='path(s) to text file(s) for training') - parser.add_argument('--valid-scp-files', nargs='+', - help='path(s) to scp file(s) for validation') + help='path(s) to text file(s) for training, where ' + 'each should matches with one in --train-feat-files') + parser.add_argument('--valid-feat-files', nargs='+', + help='path(s) to scp feature file(s) for validation') parser.add_argument('--valid-text-files', nargs='+', - help='path(s) to text file(s) for validation') - parser.add_argument('--test-scp-files', nargs='+', - help='path(s) to scp file(s) for test') + help='path(s) to text file(s) for validation, where ' + 'each should matches with one in --valid-feat-files') + parser.add_argument('--test-feat-files', nargs='+', + help='path(s) to scp feature file(s) for test') parser.add_argument('--test-text-files', nargs='+', - help='path(s) to text file(s) for test') + help='path(s) to text file(s) for test, where ' + 'each should matches with one in --test-feat-files') parser.add_argument('--dict', default=None, type=str, help='path to the dictionary') - parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL', + parser.add_argument('--left-pad-source', default='False', type=str, metavar='BOOL', help='pad the source on the left') parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL', help='pad the target on the left') parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N', - help='max number of tokens in the source sequence') + help='max number of frames in the source sequence') parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N', help='max number of tokens in the target sequence') parser.add_argument('--upsample-primary', default=1, type=int, @@ -77,7 +80,7 @@ def load_pretrained_model(path, dict_path, arg_overrides=None): args = model['args'] state_dict = model['model'] args = utils.override_model_args(args, arg_overrides) - dict = Dictionary.load(dict_path) + dict = TokenDictionary.load(dict_path) task = SpeechRecognitionTask(args, dict) model = task.build_model(args) @@ -88,6 +91,7 @@ def load_pretrained_model(path, dict_path, arg_overrides=None): def 
__init__(self, args, dict): super().__init__(args) self.dict = dict + self.iterations_in_epoch = 0 @classmethod def setup_task(cls, args, **kwargs): @@ -117,26 +121,31 @@ def load_dataset(self, split, combine=False, **kwargs): tgt_datasets = [] if split == 'train': - scp_files = self.args.train_scp_files + feat_files = self.args.train_feat_files text_files = self.args.train_text_files - assert len(scp_files) > 0 and len(text_files) > 0 + assert len(feat_files) > 0 and len(text_files) > 0 elif re.match(r"^valid\d*$", split): - scp_files = self.args.valid_scp_files - text_files = self.args.valid_text_files - assert len(scp_files) > 0 and len(text_files) > 0 + m = re.match(r"^valid(\d*)$", split) + idx = 0 if m.group(1) == '' else int(m.group(1)) + if idx >= len(self.args.valid_feat_files) or \ + idx >= len(self.args.valid_text_files): + raise FileNotFoundError + feat_files = [self.args.valid_feat_files[idx]] + text_files = [self.args.valid_text_files[idx]] + assert len(feat_files) > 0 and len(text_files) > 0 elif split == 'test': - scp_files = self.args.test_scp_files + feat_files = self.args.test_feat_files text_files = self.args.test_text_files - assert len(scp_files) > 0 and len(text_files) > 0 + assert len(feat_files) > 0 and len(text_files) > 0 else: raise ValueError('split should be one of "train", "valid*", "test"') - assert len(scp_files) == len(text_files) - file_pairs = zip(scp_files, text_files) - for scp, text in enumerate(file_pairs): - assert ScpCachedDataset.exists(scp) and TokenTextDataset.exists(text) - src_datasets.append(ScpCachedDataset(scp, ordered_indices=True)) + assert len(feat_files) == len(text_files) + file_pairs = zip(feat_files, text_files) + for feat, text in file_pairs: + assert ScpCachedDataset.exists(feat) and TokenTextDataset.exists(text) + src_datasets.append(ScpCachedDataset(feat, ordered_prefetch=True)) tgt_datasets.append(TokenTextDataset(text, self.dict)) - print('| {} {} examples'.format(scp, len(src_datasets[-1]))) + print('| {} {} examples'.format(feat, len(src_datasets[-1]))) print('| {} {} examples'.format(text, len(tgt_datasets[-1]))) if not combine: @@ -144,9 +153,14 @@ def load_dataset(self, split, combine=False, **kwargs): assert len(src_datasets) == len(tgt_datasets) + self.feat_dim = src_datasets[0].feat_dim + if len(src_datasets) == 1: src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0] else: + for i in range(1, len(src_datasets)): + assert self.feat_dim == src_datasets[i].feat_dim, \ + 'feature dimension does not match across multiple scp files' sample_ratios = [1] * len(src_datasets) sample_ratios[0] = self.args.upsample_primary src_dataset = ConcatDataset(src_datasets, sample_ratios) diff --git a/fairseq/wer.py b/fairseq/wer.py index 00362cf68..ca187e188 100644 --- a/fairseq/wer.py +++ b/fairseq/wer.py @@ -5,7 +5,7 @@ # the root directory of this source tree. An additional grant of patent rights # can be found in the PATENTS file in the same directory. 
-from collections import Counter +from collections import Counter, OrderedDict import speech_tools.utils as speech_utils @@ -13,34 +13,33 @@ class Scorer(object): def __init__(self, dict): self.dict = dict + self.ordered_utt_list = None self.reset() def reset(self): self.char_counter = Counter() self.word_counter = Counter() - self.results = '' - self.aligned_results = '' + self.results = OrderedDict() + self.aligned_results = OrderedDict() - def add_prediction(self, pred, utt_id=None): + def add_prediction(self, utt_id, pred): + if not isinstance(utt_id, str): + raise TypeError('utt_id must be a string(got {})'.format(type(utt_id))) if not isinstance(pred, str): raise TypeError('pred must be a string(got {})'.format(type(pred))) - if utt_id is not None and not isinstance(utt_id, str): - raise TypeError('utt_id must be a string(got {}) if not None' - .format(type(utt_id))) pred_words= speech_utils.Tokenizer.tokens_to_sentence(pred, self.dict) - if utt_id is not None: - self.results += utt_id + '\n' - self.results += pred_words + '\n' + assert not utt_id in self.results, \ + 'Duplicated utterance id detected: {}'.format(utt_id) + self.results[utt_id] = pred_words + '\n' - def add(self, ref, pred, utt_id=None): + def add_evaluation(self, utt_id, ref, pred): + if not isinstance(utt_id, str): + raise TypeError('utt_id must be a string(got {})'.format(type(utt_id))) if not isinstance(ref, str): raise TypeError('ref must be a string (got {})'.format(type(ref))) if not isinstance(pred, str): raise TypeError('pred must be a string(got {})'.format(type(pred))) - if utt_id is not None and not isinstance(utt_id, str): - raise TypeError('utt_id must be a string(got {}) if not None' - .format(type(utt_id))) # char level counts _, _, counter = speech_utils.edit_distance(ref.strip().split(), @@ -48,15 +47,16 @@ def add(self, ref, pred, utt_id=None): self.char_counter += counter # word level counts - ref_words = speech_utils.Tokenizer.tokens_to_sentence(ref, self.dict) + ref_words = speech_utils.Tokenizer.tokens_to_sentence(ref, self.dict, + use_unk_sym=False) pred_words= speech_utils.Tokenizer.tokens_to_sentence(pred, self.dict) ref_word_list, pred_word_list = ref_words.split(), pred_words.split() _, steps, counter = speech_utils.edit_distance(ref_word_list, pred_word_list) self.word_counter += counter - if utt_id is not None: - self.aligned_results += utt_id + '\n' - self.aligned_results += speech_utils.aligned_print(ref_word_list, + assert not utt_id in self.aligned_results, \ + 'Duplicated utterance id detected: {}'.format(utt_id) + self.aligned_results[utt_id] = speech_utils.aligned_print(ref_word_list, pred_word_list, steps) def cer(self): @@ -77,18 +77,43 @@ def wer(self): dlt = float(self.word_counter['del']) / self.word_counter['words'] * 100 return wer, sub, ins, dlt - def acc_word_error(self): + def tot_word_error(self): return self.word_counter['sub'] + self.word_counter['ins'] + \ self.word_counter['del'] - def acc_word_count(self): + def tot_word_count(self): return self.word_counter['words'] - @property - def results(self): - return self.results + def add_ordered_utt_list(self, *args): + self.ordered_utt_list = [] + for text_file in args: + with open(text_file, 'r', encoding='utf-8') as f: + one_utt_list = [line.strip().split()[0] for line in f] + self.ordered_utt_list.extend(one_utt_list) + if len(self.results): + assert set(self.ordered_utt_list) == set(self.results.keys()) + if len(self.aligned_results): + assert set(self.ordered_utt_list) == set(self.aligned_results.keys()) + + def 
print_results(self): + res = '' + if self.order_utt_list is not None: + assert set(self.ordered_utt_list) == set(self.results.keys()) + for utt_id in self.ordered_utt_list: + res += utt_id + ' ' + self.results[utt_id] + else: + for utt_id in self.results: + res += utt_id + ' ' + self.results[utt_id] + return res - @property - def aligned_results(self): - return self.aligned_results + def print_aligned_results(self): + res = '' + if self.order_utt_list is not None: + assert set(self.ordered_utt_list) == set(self.aligned_results.keys()) + for utt_id in self.ordered_utt_list: + res += utt_id + '\n' + self.aligned_results[utt_id] + else: + for utt_id in self.aligned_results: + res += utt_id + '\n' + self.aligned_results[utt_id] + return res diff --git a/speech_recognition.py b/speech_recognition.py index 54d413fa0..d4b6b6527 100644 --- a/speech_recognition.py +++ b/speech_recognition.py @@ -16,6 +16,8 @@ from fairseq import wer, options, progress_bar, tasks, utils from fairseq.meters import StopwatchMeter from fairseq.speech_recognizer import SpeechRecognizer +from fairseq.utils import import_user_module +from speech_tools.utils import plot_attention def main(args): @@ -23,6 +25,8 @@ def main(args): assert not args.sampling or args.nbest == args.beam, \ '--sampling requires --nbest to be equal to --beam' + import_user_module(args) + if args.max_tokens is None and args.max_sentences is None: args.max_tokens = 12000 print(args) @@ -74,7 +78,7 @@ def main(args): print('| The option match_source_len is not applicable to ' 'speech recognition. Ignoring it.') gen_timer = StopwatchMeter() - recognizer = SpeechSequenceGenerator( + recognizer = SpeechRecognizer( models, dict, beam_size=args.beam, minlen=args.min_len, stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized), @@ -105,9 +109,9 @@ def main(args): has_target = target_tokens is not None target_tokens = target_tokens.int().cpu() if has_target else None - # Regenerate original sentences from tokens. + # Retrieve the original sentences if has_target: - target_str = dict.string(target_tokens, args.remove_bpe) + target_str = task.dataset(args.gen_subset).tgt.get_original_tokens(sample_id) if not args.quiet: print('T-{}\t{}'.format(utt_id, target_str)) @@ -130,46 +134,61 @@ def main(args): # src_len x tgt_len attention = hypo['attention'].float().cpu() \ if hypo['attention'] is not None else None - scorer.add_prediction(hypo_str, utt_id=utt_id) + if attention is not None and args.print_alignment: + plot_attention(attention, hypo_str, utt_id, + os.path.join(args.path, 'attn_plots')) + print('| Saved attention plots in ' + \ + os.path.join(args.path, 'attn_plots')) + scorer.add_prediction(utt_id, hypo_str) if has_target: - scorer.add(target_str, hypo_str, utt_id=utt_id) + scorer.add_evaluation(utt_id, target_str, hypo_str) num_sentences += 1 print('| Recognized {} utterances in {:.1f}s ({:.2f} utterances/s)'.format( num_sentences, gen_timer.sum, 1. 
/ gen_timer.avg)) + scorer.add_ordered_utt_list(*args.test_text_files) + fn = 'results.txt' with open(os.path.join(args.path, fn), 'w', encoding='utf-8') as f: - f.write(scorer.results) + f.write(scorer.print_results()) print('| Decoded results saved as ' + f.name) if has_target: - print('| Recognize {} with beam={}: WER={:.2f}%, Sub={:.2f}%, ' - 'Ins={:.2f}%, Del={:.2f}%'.format(args.gen_subset, args.beam, - *(scorer.wer()))) - print('| CER={:.2f}%, Sub={:.2f}%, ' - 'Ins={:.2f}%, Del={:.2f}%'.format(*(scorer.cer()))) - - fn = 'wer.txt' + fn = 'wer' with open(os.path.join(args.path, fn), 'w', encoding='utf-8') as f: - f.write('WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%\n' - .format(*(scorer.wer()))) + res = 'WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%'.format( + *(scorer.wer())) + print('| Recognize {} with beam={}: '.format(args.gen_subset, args.beam) + res) + f.write(res + '\n') print('| WER saved in ' + f.name) - fn = 'cer.txt' + fn = 'cer' with open(os.path.join(args.path, fn), 'w', encoding='utf-8') as f: - f.write('CER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%\n' - .format(*(scorer.cer()))) + res = 'CER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%'.format( + *(scorer.cer())) + print('| ' + res) + f.write(res + '\n') print('| CER saved in ' + f.name) fn = 'aligned_results.txt' with open(os.path.join(args.path, fn), 'w', encoding='utf-8') as f: - f.write(scorer.aligned_results) + f.write(scorer.print_aligned_results()) print('| Aligned results saved as ' + f.name) +def print_options_meaning_changes(args): + """Options that have different meanings than those in the translation task + are explained here. + """ + print('| --max-tokens is the maximum number of input frames in a batch') + if args.print_alignment: + print('| --print-alignment is set to plot attentions') + + if __name__ == '__main__': - parser = options.get_generation_parser() + parser = options.get_generation_parser(default_task='speech_recognition') args = options.parse_args_and_arch(parser) + print_options_meaning_changes(args) main(args) diff --git a/speech_tools/__init__.py b/speech_tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/speech_tools/parse_options.sh b/speech_tools/parse_options.sh new file mode 100755 index 000000000..34476fdb3 --- /dev/null +++ b/speech_tools/parse_options.sh @@ -0,0 +1,97 @@ +#!/bin/bash + +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey); +# Arnab Ghoshal, Karel Vesely + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# Parse command-line options. +# To be sourced by another script (as in ". parse_options.sh"). +# Option format is: --option-name arg +# and shell variable "option_name" gets set to value "arg." +# The exception is --help, which takes no arguments, but prints the +# $help_message variable (if defined). + + +### +### The --config file options have lower priority to command line +### options, so we need to import them first... 
+### + +# Now import all the configs specified by command-line, in left-to-right order +for ((argpos=1; argpos<$#; argpos++)); do + if [ "${!argpos}" == "--config" ]; then + argpos_plus1=$((argpos+1)) + config=${!argpos_plus1} + [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 + . $config # source the config file. + fi +done + + +### +### No we process the command line options +### +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + # If the enclosing script is called with --help option, print the help + # message and exit. Scripts should put help messages in $help_message + --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2; + else printf "$help_message\n" 1>&2 ; fi; + exit 0 ;; + --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" + exit 1 ;; + # If the first command-line argument begins with "--" (e.g. --foo-bar), + # then work out the variable name as $name, which will equal "foo_bar". + --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; + # Next we test whether the variable in question is undefned-- if so it's + # an invalid option and we die. Note: $0 evaluates to the name of the + # enclosing script. + # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar + # is undefined. We then have to wrap this test inside "eval" because + # foo_bar is itself inside a variable ($name). + eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + + oldval="`eval echo \\$$name`"; + # Work out whether we seem to be expecting a Boolean argument. + if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval $name=\"$2\"; + + # Check that Boolean-valued arguments are really Boolean. + if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + *) break; + esac +done + + +# Check for an empty argument to the --cmd option, which can easily occur as a +# result of scripting errors. +[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; + + +true; # so this script returns exit code 0. diff --git a/speech_tools/text2token.py b/speech_tools/text2token.py new file mode 100755 index 000000000..1c90e027a --- /dev/null +++ b/speech_tools/text2token.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2019-present, Yiming Wang +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. 
+ +import argparse +import sys + +from utils import Tokenizer + + +def get_parser(): + parser = argparse.ArgumentParser( + description='Convert transcripts into tokens and write them to stdout') + # fmt: off + parser.add_argument('--skip-ncols', default=0, type=int, + help='skip first n columns') + parser.add_argument('--space', default='', type=str, + help='space symbol') + parser.add_argument('--non-lang-syms', default=None, type=str, + help='list of non-linguistic symobles, e.g., etc.') + parser.add_argument('text', type=str, nargs='?', + help='input text') + # fmt: on + + return parser + + +def main(args): + nls = None + if args.non_lang_syms is not None: + with open(args.non_lang_syms, 'r', encoding='utf-8') as f: + nls = [x.rstrip() for x in f.readlines()] + with (open(args.text, 'r', encoding='utf-8') if args.text else sys.stdin) as f: + for line in f: + entry = line.rstrip().split() + tokenized = Tokenizer.tokenize(' '.join(entry[args.skip_ncols:]), + space=args.space, non_lang_syms=nls) + print(' '.join(entry[:args.skip_ncols]) + ' ' + tokenized) + + +if __name__ == '__main__': + parser = get_parser() + args = parser.parse_args() + main(args) diff --git a/speech_tools/utils.py b/speech_tools/utils.py index 12c6c2d17..2e9132c45 100644 --- a/speech_tools/utils.py +++ b/speech_tools/utils.py @@ -5,7 +5,7 @@ # the root directory of this source tree. An additional grant of patent rights # can be found in the PATENTS file in the same directory. -import re +import os, re import numpy as np from collections import Counter @@ -18,15 +18,17 @@ class Tokenizer: @staticmethod def tokenize(sent, space='', non_lang_syms=None): + assert isinstance(sent, str) sent = ' '.join(sent.strip().split()) match_pos = [] if non_lang_syms is not None: assert isinstance(non_lang_syms, list) - prog = re.compile('|'.join(map(re.escape, non_lang_syms))) - matches = prog.finditer(sent) - for match in matches: - match_pos.append([match.start(), match.end()]) + if len(non_lang_syms) > 0: + prog = re.compile('|'.join(map(re.escape, non_lang_syms))) + matches = prog.finditer(sent) + for match in matches: + match_pos.append([match.start(), match.end()]) tokens = [] i = 0 @@ -43,7 +45,7 @@ def tokenize(sent, space='', non_lang_syms=None): def tokens_to_index_tensor(line, dict, append_eos=True): tokens = line.strip().split() ntokens = len(tokens) - ids = torch.IntTensor(ntokens + 1 if append_eos else ntokens) + ids = torch.LongTensor(ntokens + 1 if append_eos else ntokens) for i, token in enumerate(tokens): ids[i] = dict.index(token) @@ -52,13 +54,15 @@ def tokens_to_index_tensor(line, dict, append_eos=True): return ids @staticmethod - def tokens_to_sentence(line, dict): + def tokens_to_sentence(line, dict, use_unk_sym=True): + # use_unk_sym=False when we want to restore original transcripts from + # token sequences, e.g., obtain reference to compute WER tokens = line.strip().split() sent = "" for token in tokens: if token == dict.space_word: sent += " " - elif dict.index(token) == dict.unk(): + elif use_unk_sym and dict.index(token) == dict.unk(): sent += dict.unk_word elif token != dict.pad_word and token != dict.eos_word: sent += token @@ -112,6 +116,25 @@ def convert_padding_direction(src_frames, src_lengths, right_to_left=False, index = torch.remainder(range + num_pads, max_len) return src_frames.gather(1, index) +def plot_attention(attention, hypo_str, utt_id, save_dir): + """This function plots the attention for an example and save the plot in + save_dir with .pdf as its filename. 
+ """ + try: + import matplotlib.pyplot as plt + except ImportError: + raise ImportError( + """This function requires matplotlib. + Please install it to generate plots. + If you are on a cluster where you do not have admin rights you could + try using virtualenv.""") + + attn = attention.data.numpy() + plt.matshow(attn) + plt.title(hypo_str) + filename = os.path.join(save_dir, utt_id + '.pdf') + plt.savefig(filename, bbox_inches='tight') + def edit_distance(ref, hyp): """This function is to calculate the edit distance of reference sentence and the hypothesis sentence using dynamic programming, and also backtrace to get @@ -247,7 +270,7 @@ def aligned_print(ref, hyp, steps): else: assert steps[i] == 'del' or steps[i] == 'corr' idx = i - steps[:i].count('ins') - sym = 'D' if step[i] == 'del' else ' ' + sym = 'D' if steps[i] == 'del' else ' ' out_str += sym + ' ' * (len(ref[idx]) - 1) + delim counter = Counter(steps) diff --git a/speech_train.py b/speech_train.py index 61487c7c3..fc256f26d 100644 --- a/speech_train.py +++ b/speech_train.py @@ -21,9 +21,12 @@ from fairseq.data import iterators from fairseq.trainer import Trainer from fairseq.meters import AverageMeter, StopwatchMeter +from fairseq.utils import import_user_module def main(args): + import_user_module(args) + if args.max_tokens is None: args.max_tokens = 6000 print(args) @@ -57,7 +60,7 @@ def main(args): # Build trainer trainer = Trainer(args, task, model, criterion, dummy_batch, oom_batch) print('| training on {} GPUs'.format(args.distributed_world_size)) - print('| max tokens per GPU = {} and max sentences per GPU = {}'.format( + print('| max input frames per GPU = {} and max sentences per GPU = {}'.format( args.max_tokens, args.max_sentences, )) @@ -86,7 +89,7 @@ def main(args): lr = trainer.get_lr() train_meter = StopwatchMeter() train_meter.start() - valid_losses = [None] + valid_losses, valid_wers = [None], [None] valid_subsets = args.valid_subset.split(',') while lr > args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates() < max_update: # train for one epoch @@ -125,6 +128,9 @@ def train(args, trainer, task, epoch_itr): first_valid = args.valid_subset.split(',')[0] max_update = args.max_update or math.inf for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch): + if hasattr(task, 'iterations_in_epoch'): + task.iterations_in_epoch = i + log_output = trainer.train_step(samples) if log_output is None: continue @@ -227,6 +233,9 @@ def validate(args, trainer, task, epoch_itr, subsets): meter.reset() extra_meters = collections.defaultdict(lambda: AverageMeter()) + if callable(getattr(trainer.criterion, 'set_valid_tgt_dataset', None)): + trainer.criterion.set_valid_tgt_dataset(task.dataset(subset).tgt) + for sample in progress: log_output = trainer.valid_step(sample) @@ -236,7 +245,8 @@ def validate(args, trainer, task, epoch_itr, subsets): continue if k == 'word_error': extra_meters['valid_wer'].update( - v / log_output['word_count'], log_output['word_count']) + float(v) / log_output['word_count'] * 100, + log_output['word_count']) else: extra_meters[k].update(v) @@ -370,9 +380,17 @@ def distributed_main(i, args): main(args) +def print_options_meaning_changes(args): + """Options that have different meanings than those in the translation task + are explained here. 
+ """ + print('| --max-tokens is the maximum number of input frames in a batch') + + if __name__ == '__main__': - parser = options.get_training_parser() + parser = options.get_training_parser(default_task='speech_recognition') args = options.parse_args_and_arch(parser) + print_options_meaning_changes(args) if args.distributed_init_method is None: distributed_utils.infer_init_method(args) diff --git a/tests/test_speech_dataset.py b/tests/test_speech_dataset.py index e59df2ea3..86decab02 100644 --- a/tests/test_speech_dataset.py +++ b/tests/test_speech_dataset.py @@ -28,7 +28,9 @@ def make_dictionary(): alphabet = string.ascii_lowercase for token in alphabet: d.add_symbol(token) + d.add_symbol('') d.finalize(padding_factor=1) # don't add extra padding symbols + d.space_index = d.indices.get('', -1) return d @staticmethod diff --git a/tests/test_speech_utils.py b/tests/test_speech_utils.py index a13756f1d..9c3a6774b 100644 --- a/tests/test_speech_utils.py +++ b/tests/test_speech_utils.py @@ -26,9 +26,11 @@ def make_dictionary(vocab, non_lang_syms=[]): d = TokenDictionary() for token in vocab: d.add_symbol(token) + d.add_symbol('') for token in non_lang_syms: d.add_symbol(token) d.finalize(padding_factor=1) # don't add extra padding symbols + d.space_index = d.indices.get('', -1) return d @staticmethod From ee70584936d74a9e9b7f14a24322774d5d8c198f Mon Sep 17 00:00:00 2001 From: freewym Date: Sat, 26 Jan 2019 05:13:52 -0500 Subject: [PATCH 006/119] code adaptation/changes according to the commits from Jan 24, 2019 to Jan 25, 2019 --- fairseq/models/speech_lstm.py | 33 ++++++++++++++++++++------------- speech_train.py | 13 +++++++++++-- 2 files changed, 31 insertions(+), 15 deletions(-) diff --git a/fairseq/models/speech_lstm.py b/fairseq/models/speech_lstm.py index 74d535b48..c5a0657fe 100644 --- a/fairseq/models/speech_lstm.py +++ b/fairseq/models/speech_lstm.py @@ -48,6 +48,8 @@ def add_args(parser): help='decoder embedding dimension') parser.add_argument('--decoder-embed-path', type=str, metavar='STR', help='path to pre-trained decoder embedding') + parser.add_argument('--decoder-freeze-embed', action='store_true', + help='freeze decoder embeddings') parser.add_argument('--decoder-hidden-size', type=int, metavar='N', help='decoder hidden size') parser.add_argument('--decoder-layers', type=int, metavar='N', @@ -62,6 +64,9 @@ def add_args(parser): parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', help='comma separated list of adaptive softmax cutoff points. 
' 'Must be used with adaptive_loss criterion') + parser.add_argument('--share-decoder-input-output-embed', default=False, + action='store_true', + help='share decoder input and output embeddings') # Granular dropout settings (if not specified these default to --dropout) parser.add_argument('--encoder-rnn-dropout-in', type=float, metavar='D', @@ -72,9 +77,6 @@ def add_args(parser): help='dropout probability for decoder input embedding') parser.add_argument('--decoder-dropout-out', type=float, metavar='D', help='dropout probability for decoder output') - parser.add_argument('--share-decoder-input-output-embed', default=False, - action='store_true', - help='share decoder input and output embeddings') # fmt: on @classmethod @@ -107,6 +109,9 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): '--decoder-embed-dim to match --decoder-out-embed-dim' ) + if args.decoder_freeze_embed: + pretrained_decoder_embed.weight.requires_grad = False + def eval_str_nested_list_or_tuple(x, type=int): if x is None: return None @@ -294,8 +299,8 @@ def forward(self, src_tokens, src_lengths): state_size = 2 * self.num_layers, bsz, self.hidden_size else: state_size = self.num_layers, bsz, self.hidden_size - h0 = x.data.new(*state_size).zero_() - c0 = x.data.new(*state_size).zero_() + h0 = x.new_zeros(*state_size) + c0 = x.new_zeros(*state_size) packed_outs, (final_hiddens, final_cells) = self.lstm(packed_x, (h0, c0)) # unpack outputs and apply dropout @@ -393,7 +398,7 @@ def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None): bsz, seqlen = prev_output_tokens.size() # get outputs from encoder - encoder_outs, _, _ = encoder_out[:3] + encoder_outs, encoder_hiddens, encoder_cells = encoder_out[:3] srclen = encoder_outs.size(0) # embed tokens @@ -408,15 +413,14 @@ def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None): if cached_state is not None: prev_hiddens, prev_cells, input_feed = cached_state else: - _, encoder_hiddens, encoder_cells = encoder_out[:3] num_layers = len(self.layers) - prev_hiddens = [x.data.new(bsz, self.hidden_size).zero_() \ + prev_hiddens = [x.new_zeros(bsz, self.hidden_size) \ for i in range(num_layers)] - prev_cells = [x.data.new(bsz, self.hidden_size).zero_() \ + prev_cells = [x.new_zeros(bsz, self.hidden_size) \ for i in range(num_layers)] - input_feed = x.data.new(bsz, self.encoder_output_units).zero_() + input_feed = x.new_zeros(bsz, self.encoder_output_units) - attn_scores = x.data.new(srclen, seqlen, bsz).zero_() + attn_scores = x.new_zeros(srclen, seqlen, bsz) outs = [] for j in range(seqlen): # input feeding: concatenate context vector from previous time step @@ -448,7 +452,9 @@ def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None): # cache previous states (no-op except during incremental generation) utils.set_incremental_state( - self, incremental_state, 'cached_state', (prev_hiddens, prev_cells, input_feed)) + self, incremental_state, 'cached_state', + (prev_hiddens, prev_cells, input_feed), + ) # collect outputs across time steps x = torch.cat(outs, dim=0).view(seqlen, bsz, -1) @@ -532,6 +538,7 @@ def base_architecture(args): args.encoder_rnn_bidirectional = getattr(args, 'encoder_rnn_bidirectional', True) args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 320) args.decoder_embed_path = getattr(args, 'decoder_embed_path', None) + args.decoder_freeze_embed = getattr(args, 'decoder_freeze_embed', False) args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 
args.decoder_embed_dim) args.decoder_layers = getattr(args, 'decoder_layers', 3) args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 960) @@ -541,8 +548,8 @@ def base_architecture(args): args.encoder_rnn_dropout_out = getattr(args, 'encoder_rnn_dropout_out', args.dropout) args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', args.dropout) args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout) - args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False) args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,50000,200000') + args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False) @register_model_architecture('speech_lstm', 'speech_conv_lstm_wsj') def conv_lstm_wsj(args): diff --git a/speech_train.py b/speech_train.py index fc256f26d..77af82ca2 100644 --- a/speech_train.py +++ b/speech_train.py @@ -44,8 +44,12 @@ def main(args): # Build model and criterion model = task.build_model(args) criterion = task.build_criterion(args) + print(model) print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__)) - print('| num. model params: {}'.format(sum(p.numel() for p in model.parameters()))) + print('| num. model params: {} (num. trained: {})'.format( + sum(p.numel() for p in model.parameters()), + sum(p.numel() for p in model.parameters() if p.requires_grad), + )) # Make a dummy batch to (i) warm the caching allocator and (ii) as a # placeholder DistributedDataParallel when there's an uneven number of @@ -336,7 +340,10 @@ def save_checkpoint(args, trainer, epoch_itr, val_wer): def load_checkpoint(args, trainer, epoch_itr): """Load a checkpoint and replay dataloader to match.""" os.makedirs(args.save_dir, exist_ok=True) - checkpoint_path = os.path.join(args.save_dir, args.restore_file) + if os.path.isabs(args.restore_file): + checkpoint_path = args.restore_file + else: + checkpoint_path = os.path.join(args.save_dir, args.restore_file) if os.path.isfile(checkpoint_path): extra_state = trainer.load_checkpoint(checkpoint_path, args.reset_optimizer, args.reset_lr_scheduler, eval(args.optimizer_overrides)) @@ -352,6 +359,8 @@ def load_checkpoint(args, trainer, epoch_itr): if 'best' in extra_state: save_checkpoint.best = extra_state['best'] return True + else: + print('| no existing checkpoint found {}'.format(checkpoint_path)) return False From 93c1de4d3a64c67707041bead637d8369f55c222 Mon Sep 17 00:00:00 2001 From: freewym Date: Sat, 26 Jan 2019 20:35:22 -0500 Subject: [PATCH 007/119] fix --- examples/asr_wsj/run.sh | 86 ++++++++++++-------- examples/asr_wsj/wer_output_filter | 11 +++ fairseq/criterions/cross_entropy_with_wer.py | 73 ++++++++++------- fairseq/data/token_dictionary.py | 27 +++++- fairseq/models/speech_lstm.py | 3 +- fairseq/tasks/speech_recognition.py | 20 +++-- fairseq/wer.py | 53 ++++++++++-- speech_recognition.py | 45 ++++++---- speech_tools/utils.py | 9 +- speech_train.py | 10 ++- 10 files changed, 236 insertions(+), 101 deletions(-) create mode 100755 examples/asr_wsj/wer_output_filter diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index 1af8af9bf..adc7b3457 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -17,6 +17,7 @@ train_set=train_si284 valid_set=test_dev93 test_set=test_eval92 checkpoint=checkpoint_best.pt +validate_on_train=false if [ -f ./path.sh ]; then @@ -30,14 +31,14 @@ else . 
./speech_tools/parse_options.sh fi +valid_subset=valid +if $validate_on_train; then + valid_subset="$valid_subset train" +fi + dict=$data_dir/lang/${train_set}_units.txt nlsyms=$data_dir/lang/non_lang_syms.txt train_text=$data_dir/$train_set/text -train_token_text=$data_dir/$train_set/token_text -valid_text=$data_dir/$valid_set/text -valid_token_text=$data_dir/$valid_set/token_text -test_text=$data_dir/$test_set/text -test_token_text=$data_dir/$test_set/token_text if [ ${stage} -le 1 ]; then echo "Stage 1: Dictionary Preparation and Text Tokenization" mkdir -p $data_dir/lang @@ -46,49 +47,66 @@ if [ ${stage} -le 1 ]; then cut -f 2- $train_text | tr " " "\n" | sort | uniq | grep "<" > $nlsyms cat $nlsyms - echo "Making a dictionary and tokenizing text for training set..." - python3 speech_tools/text2token.py --skip-ncols 1 --space "" --non-lang-syms $nlsyms \ - $train_text > $train_token_text - cut -f 2- -d" " $train_token_text | tr " " "\n" | grep -v -e '^\s*$' | sort | \ - uniq -c | awk '{print $2,$1}' > $dict - wc -l $dict - - echo "Tokenizing text for validation and test set..." - python3 speech_tools/text2token.py --skip-ncols 1 --space "" --non-lang-syms $nlsyms \ - $valid_text > $valid_token_text - python3 speech_tools/text2token.py --skip-ncols 1 --space "" --non-lang-syms $nlsyms \ - $test_text > $test_token_text + echo "Making a dictionary and tokenizing text for train/valid/test set..." + for dataset in $train_set $valid_set $test_set; do + text=$data_dir/$dataset/text + token_text=$data_dir/$dataset/token_text + python3 speech_tools/text2token.py --skip-ncols 1 --space "" \ + --non-lang-syms $nlsyms $text > $token_text + if [ "$dataset" == "$train_set" ]; then + cut -f 2- -d" " $token_text | tr " " "\n" | grep -v -e '^\s*$' | sort | \ + uniq -c | awk '{print $2,$1}' > $dict + wc -l $dict + fi + done fi train_feat=$data_dir/dump/$train_set/deltafalse/feats.scp +train_token_text=$data_dir/$train_set/token_text valid_feat=$data_dir/dump/$valid_set/deltafalse/feats.scp +valid_token_text=$data_dir/$valid_set/token_text if [ ${stage} -le 2 ]; then echo "Stage 2: Model Training" mkdir -p $exp_dir/logs + log_file=$exp_dir/logs/train.log + [ -f $exp_dir/checkpoint_last.pt ] && log_file="-a $log_file" + opts="" + [ -f examples/asr_wsj/wer_output_filter ] && \ + opts="$opts --wer-output-filter examples/asr_wsj/wer_output_filter" [ -z "$free_gpu" ] && free_gpu=$(free-gpu) - CUDA_VISIBLE_DEVICES=$free_gpu python3 -u speech_train.py \ - --log-interval 1000 --log-format "simple" --seed 1 \ - --num-workers 0 --max-tokens 45000 --max-sentences 32 --max-sentences-valid 64 \ + CUDA_VISIBLE_DEVICES=$free_gpu python3 -u speech_train.py --seed 1 \ + --log-interval 500 --log-format "simple" --print-training-sample-interval 500 \ + --num-workers 0 --max-tokens 24000 --max-sentences 32 \ + --valid-subset $valid_subset --max-sentences-valid 64 \ --distributed-world-size 1 --distributed-rank 0 --distributed-port -1 \ - --max-epoch 20 --optimizer "adam" --lr 0.1 --weight-decay 0.0 \ - --lr-scheduler "reduce_lr_on_plateau" --lr-shrink 0.1 \ - --save-dir $exp_dir --save-interval-updates 100 --keep-interval-updates 10 \ - --keep-last-epochs 5 --validate-interval 100 \ + --max-epoch 20 --optimizer "adam" --lr 0.001 --weight-decay 0.0 \ + --lr-scheduler "reduce_lr_on_plateau" --lr-shrink 0.1 --min-lr "1e-15" \ + --save-dir $exp_dir --save-interval-updates 200 --keep-interval-updates 10 \ + --keep-last-epochs 5 --validate-interval 1 \ --arch "speech_conv_lstm_wsj" --criterion "cross_entropy_with_wer" \ 
--train-feat-files $train_feat --train-text-files $train_token_text \ --valid-feat-files $valid_feat --valid-text-files $valid_token_text \ - --dict $dict --max-source-positions 9999 --max-target-positions 999 2>&1 | tee $exp_dir/logs/train.log -fi + --dict $dict --non-lang-syms $nlsyms \ + --max-source-positions 9999 --max-target-positions 999 $opts 2>&1 | tee $log_file exit 0 +fi -test_feat=$data_dir/dump/$test_set/deltafalse/feats.scp -if [ ${stage} -le 2 ]; then +if [ ${stage} -le 3 ]; then echo "Stage 3: Decoding" + opts="" + [ -f examples/asr_wsj/wer_output_filter ] && \ + opts="$opts --wer-output-filter examples/asr_wsj/wer_output_filter" [ -z "$free_gpu" ] && free_gpu=$(free-gpu) - CUDA_VISIBLE_DEVICES=$free_gpu python3 -u speech_recognition.py \ - --max-tokens 45000 --max-sentences 32 --num-shards 1 --shard-id 0 \ - --test-feat-files $test_feat --test-text-files $test_token_text \ - --dict $dict --max-source-positions 9999 --max-target-positions 999 \ - --path $exp_dir/$checkpoint --beam 10 --max-len-a 0.5 --max-len-b 0 \ - --lenpen 1.0 --print-alignment 2>&1 | tee $exp_dir/logs/decode.log + for dataset in $valid_set $test_set; do + feat=$data_dir/dump/$dataset/deltafalse/feats.scp + text=$data_dir/$dataset/token_text + CUDA_VISIBLE_DEVICES=$free_gpu python3 -u speech_recognition.py \ + --max-tokens 45000 --max-sentences 32 --num-shards 1 --shard-id 0 \ + --test-feat-files $feat --test-text-files $text \ + --dict $dict --non-lang-syms $nlsyms \ + --max-source-positions 9999 --max-target-positions 999 \ + --path $exp_dir/$checkpoint --beam 10 --max-len-a 0.5 --max-len-b 0 \ + --lenpen 1.0 --output-dir $exp_dir/decode_$dataset --print-alignment $opts \ + 2>&1 | tee $exp_dir/logs/decode_$dataset.log + done fi diff --git a/examples/asr_wsj/wer_output_filter b/examples/asr_wsj/wer_output_filter new file mode 100755 index 000000000..939cd0872 --- /dev/null +++ b/examples/asr_wsj/wer_output_filter @@ -0,0 +1,11 @@ +#!/bin/sed -f +s:::g +s:::g +s:::g +s/://g +s/\*//g +s/-HOLDER/HOLDER/g +s/COMPAIGN/CAMPAIGN/g +s/APPROACHES-/APPROACHES/g +s/RESEACHERS/RESEARCHERS/g + diff --git a/fairseq/criterions/cross_entropy_with_wer.py b/fairseq/criterions/cross_entropy_with_wer.py index 4aac9efea..c3cfa0305 100644 --- a/fairseq/criterions/cross_entropy_with_wer.py +++ b/fairseq/criterions/cross_entropy_with_wer.py @@ -12,6 +12,8 @@ from fairseq import utils, wer from fairseq.data import data_utils +from speech_tools.utils import Tokenizer + from . 
import FairseqCriterion, register_criterion from .cross_entropy import CrossEntropyCriterion @@ -23,22 +25,26 @@ def __init__(self, args, task): super().__init__(args, task) dict = task.dict if hasattr(task, 'dict') else getattr(task, 'tgt_dict') - self.scorer = wer.Scorer(dict) - self.num_calls = 0 #getattr(task, 'iterations_in_epoch', 0) + self.scorer = wer.Scorer(dict, + wer_output_filter=task.args.wer_output_filter) self.train_tgt_dataset = task.dataset(args.train_subset).tgt self.valid_tgt_dataset = None + self.num_updates = -1 @staticmethod def add_args(parser): """Add criterion-specific arguments to the parser.""" # fmt: off - parser.add_argument('--sample-results-interval', type=int, metavar='N', default=500, - help='print sample results interval every this ' - 'number of forward times') + parser.add_argument('--print-training-sample-interval', type=int, + metavar='N', dest='print_interval', default=500, + help='print a training sample (reference + ' + 'prediction) every this number of updates') # fmt: on def forward(self, model, sample, reduce=True): - """Compute the loss for the given sample. + """Compute the loss for the given sample; periodically print out + randomly sampled predictions if model is in training mode, otherwise + aggregate word error stats for validation. Returns a tuple with three elements: 1) the loss @@ -47,12 +53,11 @@ def forward(self, model, sample, reduce=True): """ net_output = model(**sample['net_input']) lprobs = model.get_normalized_probs(net_output, log_probs=True) - # wer stats code starts - if not model.training or \ - self.num_calls % self.args.sample_results_interval == 0: - pred = lprobs.argmax(-1).long().cpu() # bsz x len - target = sample['target'].long().cpu() # bsz x len - assert pred.size() == sample['net_input']['prev_output_tokens'].size() + target = model.get_targets(sample, net_output) + # word error stats code starts + if not model.training or (self.num_updates // self.args.print_interval > + (self.num_updates - 1) // self.args.print_interval): + pred = lprobs.argmax(-1).cpu() # bsz x len assert pred.size() == target.size() def lengths_strip_padding(idx_array, padding_idx): @@ -66,7 +71,7 @@ def lengths_strip_padding(idx_array, padding_idx): return len(idx_array) return [lengths_strip_padding(row, padding_idx) for row in idx_array] - target_lengths = lengths_strip_padding(target.numpy(), + target_lengths = lengths_strip_padding(target.data.cpu().numpy(), self.padding_idx) dict = self.scorer.dict if not model.training: # validation step, compute WER stats with scorer @@ -74,26 +79,24 @@ def lengths_strip_padding(idx_array, padding_idx): for i, length in enumerate(target_lengths): utt_id = sample['utt_id'][i] id = sample['id'].data[i] - #ref_str = dict.string(target.data[i]) - ref_str = self.valid_tgt_dataset.get_original_tokens(id) - pred_str = dict.string(pred.data[i][:length]) - self.scorer.add_evaluation(utt_id, ref_str, pred_str) - else: # print a randomly sampled result every sample_results_interval batch - with data_utils.numpy_seed(self.num_calls): + #ref_tokens = dict.string(target.data[i]) + ref_tokens = self.valid_tgt_dataset.get_original_tokens(id) + pred_tokens = dict.string(pred.data[i][:length]) + self.scorer.add_evaluation(utt_id, ref_tokens, pred_tokens) + else: # print a randomly sampled result every print_interval updates + with data_utils.numpy_seed(self.num_updates): i = np.random.randint(0, len(sample['id'])) id = sample['id'].data[i] - #ref_str_one = dict.string(target.data[i]) - ref_str_one = 
self.train_tgt_dataset.get_original_tokens(id) - pred_str_one = dict.string(pred.data[i][:target_lengths[i]]) - print('| ' + sample['utt_id'][i]) - print('| sample REF: ' + ref_str_one) - print('| sample PRD: ' + pred_str_one) - self.num_calls += 1 - # wer stats code ends + #ref_one = Tokenizer.tokens_to_sentence(dict.string(target.data[i]), dict) + ref_one = self.train_tgt_dataset.get_original_text(id, dict) + pred_one = Tokenizer.tokens_to_sentence( + dict.string(pred.data[i][:target_lengths[i]]), dict) + print('| sample REF: ' + ref_one) + print('| sample PRD: ' + pred_one) + # word error stats code ends lprobs = lprobs.view(-1, lprobs.size(-1)) - target = model.get_targets(sample, net_output).view(-1) - loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx, - reduce=reduce) + loss = F.nll_loss(lprobs, target.view(-1), ignore_index=self.padding_idx, + reduction='sum' if reduce else 'none') sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] logging_output = { 'loss': utils.item(loss.data) if reduce else loss.data, @@ -104,6 +107,8 @@ def lengths_strip_padding(idx_array, padding_idx): if not model.training: # do not compute word error in training mode logging_output['word_error'] = self.scorer.tot_word_error() logging_output['word_count'] = self.scorer.tot_word_count() + logging_output['char_error'] = self.scorer.tot_char_error() + logging_output['char_count'] = self.scorer.tot_char_count() return loss, sample_size, logging_output @staticmethod @@ -112,10 +117,18 @@ def aggregate_logging_outputs(logging_outputs): agg_output = CrossEntropyCriterion.aggregate_logging_outputs(logging_outputs) word_error = sum(log.get('word_error', 0) for log in logging_outputs) word_count = sum(log.get('word_count', 0) for log in logging_outputs) + char_error = sum(log.get('char_error', 0) for log in logging_outputs) + char_count = sum(log.get('char_count', 0) for log in logging_outputs) if word_count > 0: # model.training == False agg_output['word_error'] = word_error agg_output['word_count'] = word_count + if char_count > 0: # model.training == False + agg_output['char_error'] = char_error + agg_output['char_count'] = char_count return agg_output def set_valid_tgt_dataset(self, dataset): self.valid_tgt_dataset = dataset + + def set_num_updates(self, num_updates): + self.num_updates = num_updates diff --git a/fairseq/data/token_dictionary.py b/fairseq/data/token_dictionary.py index 514b406bd..0c44bcc2b 100644 --- a/fairseq/data/token_dictionary.py +++ b/fairseq/data/token_dictionary.py @@ -22,6 +22,7 @@ def __init__(self, pad='', eos='', unk='', space=''): self.eos_index = self.add_symbol(eos) self.unk_index = self.add_symbol(unk) self.nspecial = len(self.symbols) + self.non_lang_syms = None def string(self, tensor, bpe_symbol=None, escape_unk=False): """Helper for converting a tensor of token indices to a string. @@ -50,7 +51,7 @@ def space(self): return self.space_index @classmethod - def load(cls, f, ignore_utf_errors=False): + def load(cls, f, f_non_lang_syms=None, ignore_utf_errors=False): """Loads the dictionary from a text file with the format: ``` @@ -58,11 +59,33 @@ def load(cls, f, ignore_utf_errors=False): ... 
``` - and identifies the space symbol if it exists, by obtaining its index + + Identifies the space symbol if it exists, by obtaining its index (space_index=-1 if no space symbol) + + Loads non_lang_syms from another text file, if it exists, with one + symbol per line """ d = super().load(f, ignore_utf_errors) d.space_index = d.indices.get(d.space_word, -1) + + if f_non_lang_syms is not None: + assert isinstance(f_non_lang_syms, str) + try: + with open(f_non_lang_syms, 'r', encoding='utf-8', + errors='ignore' if ignore_utf_errors else None) as fd: + non_lang_syms = [x.rstrip() for x in fd.readlines()] + except FileNotFoundError as fnfe: + raise fnfe + except UnicodeError: + raise Exception("Incorrect encoding detected in {}, please " + "rebuild the dataset".format(f)) + + for sym in non_lang_syms: + assert d.index(sym) != d.unk(), \ + '{} in {} is not in the dictionary'.format(sym, f_non_lang_syms) + d.non_lang_syms = non_lang_syms + return d def dummy_sentence(self, length): diff --git a/fairseq/models/speech_lstm.py b/fairseq/models/speech_lstm.py index c5a0657fe..1d6ba4bf3 100644 --- a/fairseq/models/speech_lstm.py +++ b/fairseq/models/speech_lstm.py @@ -139,6 +139,7 @@ def eval_str_nested_list_or_tuple(x, type=int): conv_layers = ConvBNReLU(out_channels, kernel_sizes, strides, in_channels=in_channels) if not out_channels is None else None + print('| input feature dimension: {}'.format(task.feat_dim)) assert task.feat_dim % in_channels == 0 rnn_encoder_input_size = task.feat_dim // in_channels if conv_layers is not None: @@ -398,7 +399,7 @@ def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None): bsz, seqlen = prev_output_tokens.size() # get outputs from encoder - encoder_outs, encoder_hiddens, encoder_cells = encoder_out[:3] + encoder_outs = encoder_out[0] srclen = encoder_outs.size(0) # embed tokens diff --git a/fairseq/tasks/speech_recognition.py b/fairseq/tasks/speech_recognition.py index 23e21fae2..b4b56d002 100644 --- a/fairseq/tasks/speech_recognition.py +++ b/fairseq/tasks/speech_recognition.py @@ -25,7 +25,7 @@ @register_task('speech_recognition') class SpeechRecognitionTask(FairseqTask): """ - Translate from speech (source) to token text (target). + Transcribe from speech (source) to token text (target). Args: dict (Dictionary): dictionary for the output tokens @@ -35,7 +35,7 @@ class SpeechRecognitionTask(FairseqTask): The speech recognition task is compatible with :mod:`train.py `, :mod:`generate.py ` and :mod:`interactive.py `. - The speech_recognition task provides the following additional command-line + The speech recognition task provides the following additional command-line arguments: .. argparse:: @@ -46,6 +46,7 @@ class SpeechRecognitionTask(FairseqTask): @staticmethod def add_args(parser): """Add task-specific arguments to the parser.""" + # fmt: off parser.add_argument('--train-feat-files', nargs='+', help='path(s) to scp feature file(s) for training') parser.add_argument('--train-text-files', nargs='+', @@ -63,6 +64,11 @@ def add_args(parser): 'each should matches with one in --test-feat-files') parser.add_argument('--dict', default=None, type=str, help='path to the dictionary') + parser.add_argument('--non-lang-syms', default=None, type=str, + help='list of non-linguistic symbols, e.g., ' + 'etc. 
To be filtered out when calculating WER/CER') + parser.add_argument('--wer-output-filter', default=None, type=str, + help='path to wer_output_filter file for WER evaluation') parser.add_argument('--left-pad-source', default='False', type=str, metavar='BOOL', help='pad the source on the left') parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL', @@ -73,14 +79,16 @@ def add_args(parser): help='max number of tokens in the target sequence') parser.add_argument('--upsample-primary', default=1, type=int, help='amount to upsample primary dataset') + # fmt: off @staticmethod - def load_pretrained_model(path, dict_path, arg_overrides=None): + def load_pretrained_model(path, dict_path, non_lang_syms=None, + arg_overrides=None): model = utils.load_checkpoint_to_cpu(path) args = model['args'] state_dict = model['model'] args = utils.override_model_args(args, arg_overrides) - dict = TokenDictionary.load(dict_path) + dict = TokenDictionary.load(dict_path, f_non_lang_syms=non_lang_syms) task = SpeechRecognitionTask(args, dict) model = task.build_model(args) @@ -91,7 +99,6 @@ def load_pretrained_model(path, dict_path, arg_overrides=None): def __init__(self, args, dict): super().__init__(args) self.dict = dict - self.iterations_in_epoch = 0 @classmethod def setup_task(cls, args, **kwargs): @@ -106,7 +113,8 @@ def setup_task(cls, args, **kwargs): # load dictionaries dict_path = os.path.join(os.path.dirname(args.text_files[0]), 'dict.txt') if args.dict is None else args.dict - dict = TokenDictionary.load(dict_path) + dict = TokenDictionary.load(dict_path, + f_non_lang_syms=args.non_lang_syms) print('| dictionary: {} types'.format(len(dict))) return cls(args, dict) diff --git a/fairseq/wer.py b/fairseq/wer.py index ca187e188..bc9782ea7 100644 --- a/fairseq/wer.py +++ b/fairseq/wer.py @@ -5,15 +5,19 @@ # the root directory of this source tree. An additional grant of patent rights # can be found in the PATENTS file in the same directory. 
+import re + from collections import Counter, OrderedDict import speech_tools.utils as speech_utils class Scorer(object): - def __init__(self, dict): + def __init__(self, dict, wer_output_filter=None): self.dict = dict self.ordered_utt_list = None + self.word_filters = [] + self.parse_wer_output_filter(wer_output_filter) self.reset() def reset(self): @@ -22,13 +26,31 @@ def reset(self): self.results = OrderedDict() self.aligned_results = OrderedDict() + def parse_wer_output_filter(self, wer_output_filter): + if wer_output_filter: + with open(wer_output_filter, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if line.startswith('#!') or line == '': + continue + elif line.startswith('s/'): + m = re.match(r's/(\S+)/(\w*)/g', line) + assert m is not None + self.word_filters.append([m.group(1), m.group(2)]) + elif line.startswith('s:'): + m = re.match(r's:(\S+):(\w*):g', line) + assert m is not None + self.word_filters.append([m.group(1), m.group(2)]) + else: + print('Unsupported pattern: "' + line + '", ignored') + def add_prediction(self, utt_id, pred): if not isinstance(utt_id, str): raise TypeError('utt_id must be a string(got {})'.format(type(utt_id))) if not isinstance(pred, str): raise TypeError('pred must be a string(got {})'.format(type(pred))) - pred_words= speech_utils.Tokenizer.tokens_to_sentence(pred, self.dict) + pred_words = speech_utils.Tokenizer.tokens_to_sentence(pred, self.dict) assert not utt_id in self.results, \ 'Duplicated utterance id detected: {}'.format(utt_id) self.results[utt_id] = pred_words + '\n' @@ -41,6 +63,14 @@ def add_evaluation(self, utt_id, ref, pred): if not isinstance(pred, str): raise TypeError('pred must be a string(got {})'.format(type(pred))) + # filter out any non_lang_syms from ref and pred + non_lang_syms = getattr(self.dict, 'non_lang_syms', None) + assert non_lang_syms is None or isinstance(non_lang_syms, list) + if non_lang_syms is not None and len(non_lang_syms) > 0: + ref_list, pred_list = ref.strip().split(), pred.strip().split() + ref = ' '.join([x for x in ref_list if x not in non_lang_syms]) + pred = ' '.join([x for x in pred_list if x not in non_lang_syms]) + # char level counts _, _, counter = speech_utils.edit_distance(ref.strip().split(), pred.strip().split()) @@ -49,7 +79,13 @@ def add_evaluation(self, utt_id, ref, pred): # word level counts ref_words = speech_utils.Tokenizer.tokens_to_sentence(ref, self.dict, use_unk_sym=False) - pred_words= speech_utils.Tokenizer.tokens_to_sentence(pred, self.dict) + pred_words = speech_utils.Tokenizer.tokens_to_sentence(pred, self.dict) + + # filter words according to self.word_filters (support re.sub only) + for pattern, repl in self.word_filters: + ref_words = re.sub(pattern, repl, ref_words) + pred_words = re.sub(pattern, repl, pred_words) + ref_word_list, pred_word_list = ref_words.split(), pred_words.split() _, steps, counter = speech_utils.edit_distance(ref_word_list, pred_word_list) @@ -84,6 +120,13 @@ def tot_word_error(self): def tot_word_count(self): return self.word_counter['words'] + def tot_char_error(self): + return self.char_counter['sub'] + self.char_counter['ins'] + \ + self.char_counter['del'] + + def tot_char_count(self): + return self.char_counter['words'] + def add_ordered_utt_list(self, *args): self.ordered_utt_list = [] for text_file in args: @@ -97,7 +140,7 @@ def add_ordered_utt_list(self, *args): def print_results(self): res = '' - if self.order_utt_list is not None: + if self.ordered_utt_list is not None: assert set(self.ordered_utt_list) == 
set(self.results.keys()) for utt_id in self.ordered_utt_list: res += utt_id + ' ' + self.results[utt_id] @@ -108,7 +151,7 @@ def print_results(self): def print_aligned_results(self): res = '' - if self.order_utt_list is not None: + if self.ordered_utt_list is not None: assert set(self.ordered_utt_list) == set(self.aligned_results.keys()) for utt_id in self.ordered_utt_list: res += utt_id + '\n' + self.aligned_results[utt_id] diff --git a/speech_recognition.py b/speech_recognition.py index d4b6b6527..794945dd7 100644 --- a/speech_recognition.py +++ b/speech_recognition.py @@ -14,10 +14,10 @@ import torch from fairseq import wer, options, progress_bar, tasks, utils -from fairseq.meters import StopwatchMeter +from fairseq.meters import StopwatchMeter, TimeMeter from fairseq.speech_recognizer import SpeechRecognizer from fairseq.utils import import_user_module -from speech_tools.utils import plot_attention +from speech_tools.utils import Tokenizer, plot_attention def main(args): @@ -36,7 +36,7 @@ def main(args): # Load dataset split task = tasks.setup_task(args) task.load_dataset(args.gen_subset) - print('| {} {} {} examples'.format(args.data, args.gen_subset, + print('| {} {} examples'.format(args.gen_subset, len(task.dataset(args.gen_subset)))) # Set dictionary @@ -94,7 +94,7 @@ def main(args): recognizer.cuda() # Generate and compute WER - scorer = wer.Scorer(dict) + scorer = wer.Scorer(dict, wer_output_filter=args.wer_output_filter) num_sentences = 0 has_target = True with progress_bar.build_progress_bar(args, itr) as t: @@ -113,14 +113,19 @@ def main(args): if has_target: target_str = task.dataset(args.gen_subset).tgt.get_original_tokens(sample_id) if not args.quiet: - print('T-{}\t{}'.format(utt_id, target_str)) + target_sent = Tokenizer.tokens_to_sentence(target_str, dict, + use_unk_sym=False) + print('T-{}\t{}'.format(utt_id, target_sent)) # Process top predictions for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]): - hypo_str = dict.string(hypo['tokens'].int().cpu(), remove_bpe) + hypo_str = dict.string(hypo['tokens'].int().cpu(), args.remove_bpe) + if not args.quiet or i == 0: + hypo_sent = Tokenizer.tokens_to_sentence(hypo_str, dict) if not args.quiet: - print('H-{}\t{}\t{}'.format(utt_id, hypo['score'], hypo_str)) + print('H-{}\t{}\t{}'.format(utt_id, hypo_sent, hypo['score'])) + ''' print('P-{}\t{}'.format( utt_id, ' '.join(map( @@ -128,17 +133,17 @@ def main(args): hypo['positional_scores'].tolist(), )) )) + ''' # Score and obtain attention only the top hypothesis if i == 0: # src_len x tgt_len attention = hypo['attention'].float().cpu() \ if hypo['attention'] is not None else None - if attention is not None and args.print_alignment: - plot_attention(attention, hypo_str, utt_id, - os.path.join(args.path, 'attn_plots')) - print('| Saved attention plots in ' + \ - os.path.join(args.path, 'attn_plots')) + if attention is not None: + save_dir = os.path.join(args.output_dir, 'attn_plots') + os.makedirs(save_dir, exist_ok=True) + plot_attention(attention, hypo_sent, utt_id, save_dir) scorer.add_prediction(utt_id, hypo_str) if has_target: scorer.add_evaluation(utt_id, target_str, hypo_str) @@ -147,17 +152,21 @@ def main(args): print('| Recognized {} utterances in {:.1f}s ({:.2f} utterances/s)'.format( num_sentences, gen_timer.sum, 1. 
/ gen_timer.avg)) + if args.print_alignment: + print('| Saved attention plots in ' + save_dir) scorer.add_ordered_utt_list(*args.test_text_files) - fn = 'results.txt' - with open(os.path.join(args.path, fn), 'w', encoding='utf-8') as f: + os.makedirs(args.output_dir, exist_ok=True) + + fn = 'decoded_results.txt' + with open(os.path.join(args.output_dir, fn), 'w', encoding='utf-8') as f: f.write(scorer.print_results()) print('| Decoded results saved as ' + f.name) if has_target: fn = 'wer' - with open(os.path.join(args.path, fn), 'w', encoding='utf-8') as f: + with open(os.path.join(args.output_dir, fn), 'w', encoding='utf-8') as f: res = 'WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%'.format( *(scorer.wer())) print('| Recognize {} with beam={}: '.format(args.gen_subset, args.beam) + res) @@ -165,7 +174,7 @@ def main(args): print('| WER saved in ' + f.name) fn = 'cer' - with open(os.path.join(args.path, fn), 'w', encoding='utf-8') as f: + with open(os.path.join(args.output_dir, fn), 'w', encoding='utf-8') as f: res = 'CER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%'.format( *(scorer.cer())) print('| ' + res) @@ -173,7 +182,7 @@ def main(args): print('| CER saved in ' + f.name) fn = 'aligned_results.txt' - with open(os.path.join(args.path, fn), 'w', encoding='utf-8') as f: + with open(os.path.join(args.output_dir, fn), 'w', encoding='utf-8') as f: f.write(scorer.print_aligned_results()) print('| Aligned results saved as ' + f.name) @@ -189,6 +198,8 @@ def print_options_meaning_changes(args): if __name__ == '__main__': parser = options.get_generation_parser(default_task='speech_recognition') + parser.add_argument('--output-dir', metavar='DIR', required=True, + help='path to output results') args = options.parse_args_and_arch(parser) print_options_meaning_changes(args) main(args) diff --git a/speech_tools/utils.py b/speech_tools/utils.py index 2e9132c45..de847008d 100644 --- a/speech_tools/utils.py +++ b/speech_tools/utils.py @@ -116,24 +116,27 @@ def convert_padding_direction(src_frames, src_lengths, right_to_left=False, index = torch.remainder(range + num_pads, max_len) return src_frames.gather(1, index) -def plot_attention(attention, hypo_str, utt_id, save_dir): +def plot_attention(attention, hypo_sent, utt_id, save_dir): """This function plots the attention for an example and save the plot in save_dir with .pdf as its filename. """ try: + import matplotlib as mpl + mpl.use('Agg') import matplotlib.pyplot as plt except ImportError: raise ImportError( """This function requires matplotlib. - Please install it to generate plots. + Please install it to generate plots, or unset --print-alignment. 
If you are on a cluster where you do not have admin rights you could try using virtualenv.""") attn = attention.data.numpy() plt.matshow(attn) - plt.title(hypo_str) + plt.title(hypo_sent, fontsize=8) filename = os.path.join(save_dir, utt_id + '.pdf') plt.savefig(filename, bbox_inches='tight') + plt.close() def edit_distance(ref, hyp): """This function is to calculate the edit distance of reference sentence and diff --git a/speech_train.py b/speech_train.py index 77af82ca2..be0828ee0 100644 --- a/speech_train.py +++ b/speech_train.py @@ -132,8 +132,8 @@ def train(args, trainer, task, epoch_itr): first_valid = args.valid_subset.split(',')[0] max_update = args.max_update or math.inf for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch): - if hasattr(task, 'iterations_in_epoch'): - task.iterations_in_epoch = i + if callable(getattr(trainer.criterion, 'set_num_updates', None)): + trainer.criterion.set_num_updates(trainer.get_num_updates()) log_output = trainer.train_step(samples) if log_output is None: @@ -245,12 +245,16 @@ def validate(args, trainer, task, epoch_itr, subsets): for k, v in log_output.items(): if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', - 'sample_size', 'word_count']: + 'sample_size', 'word_count', 'char_count']: continue if k == 'word_error': extra_meters['valid_wer'].update( float(v) / log_output['word_count'] * 100, log_output['word_count']) + elif k == 'char_error': + extra_meters['valid_cer'].update( + float(v) / log_output['char_count'] * 100, + log_output['char_count']) else: extra_meters[k].update(v) From e864565807b46229fe5b9ea996dab605bac69f60 Mon Sep 17 00:00:00 2001 From: freewym Date: Thu, 7 Feb 2019 03:11:13 -0500 Subject: [PATCH 008/119] validation on wer --- fairseq/criterions/cross_entropy_with_wer.py | 101 ++++++++++++++----- fairseq/models/speech_lstm.py | 8 +- speech_train.py | 4 +- 3 files changed, 86 insertions(+), 27 deletions(-) diff --git a/fairseq/criterions/cross_entropy_with_wer.py b/fairseq/criterions/cross_entropy_with_wer.py index c3cfa0305..30f155e56 100644 --- a/fairseq/criterions/cross_entropy_with_wer.py +++ b/fairseq/criterions/cross_entropy_with_wer.py @@ -7,10 +7,12 @@ import math import numpy as np +import torch import torch.nn.functional as F from fairseq import utils, wer from fairseq.data import data_utils +from fairseq.models import FairseqIncrementalDecoder from speech_tools.utils import Tokenizer @@ -51,46 +53,84 @@ def forward(self, model, sample, reduce=True): 2) the sample size, which is used as the denominator for the gradient 3) logging outputs to display while training """ - net_output = model(**sample['net_input']) - lprobs = model.get_normalized_probs(net_output, log_probs=True) - target = model.get_targets(sample, net_output) + dict = self.scorer.dict + if model.training: + net_output = model(**sample['net_input']) + lprobs = model.get_normalized_probs(net_output, log_probs=True) + target = model.get_targets(sample, net_output) + else: + assert isinstance(model.decoder, FairseqIncrementalDecoder) + incremental_states = {} + encoder_input = { + k: v for k, v in sample['net_input'].items() + if k != 'prev_output_tokens' + } + encoder_out = model.encoder(**encoder_input) + target = sample['target'] + # make the maximum decoding length equal to at least the length of + # target, and the length of encoder_out if possible + # and at least the length of target + maxlen = max(encoder_out['encoder_out'][0].size(1), target.size(1)) + tokens = target.new_full([target.size(0), maxlen + 2], dict.pad()) + 
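# --- [editor's sketch; annotation only, not part of the patch] -----------------
# Shape of the greedy-decoding buffers built in this validation branch:
#   tokens: bsz x (maxlen + 2), pre-filled with dict.pad(); column 0 is set to
#           dict.eos() on the next line and serves as the begin-of-sentence symbol.
#   The loop below runs for maxlen + 1 steps and writes the argmax prediction of
#   step t into tokens[:, t + 1]; once every hypothesis has emitted eos, lprobs is
#   padded with uniform dummy rows up to target.size(1) and the loop breaks early.
#   Stacking lprobs therefore yields bsz x min(tgtlen, maxlen + 1) x vocab, the
#   same layout the training branch gets from a single forward pass.
# --------------------------------------------------------------------------------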
tokens[:, 0] = dict.eos() + lprobs = [] + attn = [] if model.decoder.need_attn else None + dummy_log_probs = encoder_out['encoder_out'][0].new_full( + [target.size(0), len(dict)], -np.log(len(dict))) + for step in range(maxlen + 1): # one extra step for EOS marker + is_eos = tokens[:, step].eq(dict.eos()) + # if all predictions are finished (i.e., ended with eos), + # pad lprobs to target length with dummy log probs, + # truncate tokens up to this step and break + if step > 0 and is_eos.sum() == is_eos.size(0): + for _ in range(step, target.size(1)): + lprobs.append(dummy_log_probs) + tokens = tokens[:, :step + 1] + break + log_probs, attn_scores = self._decode(tokens[:, :step + 1], + model, encoder_out, incremental_states) + log_probs[:, dict.pad()] = -math.inf # never select pad + tokens[:, step + 1] = log_probs.argmax(-1) + if step > 0: # deal with finished predictions + # make log_probs uniform if the previous output token is EOS + # and add consecutive EOS to the end of prediction + log_probs[is_eos, :] = -np.log(log_probs.size(1)) + tokens[is_eos, step + 1] = dict.eos() + if step < target.size(1): + lprobs.append(log_probs) + if model.decoder.need_attn: + attn.append(attn_scores) + # bsz x min(tgtlen, maxlen + 1) x vocab_size + lprobs = torch.stack(lprobs, dim=1) + if model.decoder.need_attn: + # bsz x (maxlen + 1) x (length of encoder_out) + attn = torch.stack(attn, dim=1) # word error stats code starts if not model.training or (self.num_updates // self.args.print_interval > (self.num_updates - 1) // self.args.print_interval): - pred = lprobs.argmax(-1).cpu() # bsz x len - assert pred.size() == target.size() - - def lengths_strip_padding(idx_array, padding_idx): - # assume sequences are right-padded, so work out the length by - # looking for the first occurence of padding_idx - assert idx_array.ndim == 1 or idx_array.ndim == 2 - if idx_array.ndim == 1: - try: - return idx_array.tolist().index(padding_idx) - except ValueError: - return len(idx_array) - return [lengths_strip_padding(row, padding_idx) for row in idx_array] - - target_lengths = lengths_strip_padding(target.data.cpu().numpy(), - self.padding_idx) - dict = self.scorer.dict + pred = lprobs.argmax(-1).cpu() if model.training else \ + tokens[:, 1:].data.cpu() # bsz x len + if not model.training: # validation step, compute WER stats with scorer + assert pred.size(0) == target.size(0) self.scorer.reset() - for i, length in enumerate(target_lengths): + for i in range(target.size(0)): utt_id = sample['utt_id'][i] id = sample['id'].data[i] #ref_tokens = dict.string(target.data[i]) ref_tokens = self.valid_tgt_dataset.get_original_tokens(id) - pred_tokens = dict.string(pred.data[i][:length]) + pred_tokens = dict.string(pred.data[i]) self.scorer.add_evaluation(utt_id, ref_tokens, pred_tokens) else: # print a randomly sampled result every print_interval updates + assert pred.size() == target.size() with data_utils.numpy_seed(self.num_updates): i = np.random.randint(0, len(sample['id'])) id = sample['id'].data[i] + length = utils.strip_pad(target.data[i], self.padding_idx).size(0) #ref_one = Tokenizer.tokens_to_sentence(dict.string(target.data[i]), dict) ref_one = self.train_tgt_dataset.get_original_text(id, dict) pred_one = Tokenizer.tokens_to_sentence( - dict.string(pred.data[i][:target_lengths[i]]), dict) + dict.string(pred.data[i][:length]), dict) print('| sample REF: ' + ref_one) print('| sample PRD: ' + pred_one) # word error stats code ends @@ -127,6 +167,21 @@ def aggregate_logging_outputs(logging_outputs): 
agg_output['char_count'] = char_count return agg_output + def _decode(self, tokens, model, encoder_out, incremental_states): + with torch.no_grad(): + decoder_out = list(model.decoder(tokens, encoder_out, + incremental_state=incremental_states)) + decoder_out[0] = decoder_out[0][:, -1, :] + attn = decoder_out[1] + if type(attn) is dict: + attn = attn['attn'] + if attn is not None: + if type(attn) is dict: + attn = attn['attn'] + attn = attn[:, -1, :] + probs = model.get_normalized_probs(decoder_out, log_probs=True) + return probs, attn + def set_valid_tgt_dataset(self, dataset): self.valid_tgt_dataset = dataset diff --git a/fairseq/models/speech_lstm.py b/fairseq/models/speech_lstm.py index 1d6ba4bf3..c8cc7b80b 100644 --- a/fairseq/models/speech_lstm.py +++ b/fairseq/models/speech_lstm.py @@ -61,6 +61,8 @@ def add_args(parser): help='attention type') parser.add_argument('--attention-dim', type=int, metavar='N', help='attention dimension') + parser.add_argument('--need-attention', default=False, action='store_true', + help='need to return attention tensor for the caller') parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', help='comma separated list of adaptive softmax cutoff points. ' 'Must be used with adaptive_loss criterion') @@ -173,6 +175,7 @@ def eval_str_nested_list_or_tuple(x, type=int): encoder_output_units=encoder.output_units, attn_type=args.attention_type, attn_dim=args.attention_dim, + need_attn=args.need_attention, pretrained_embed=pretrained_decoder_embed, share_input_output_embed=args.share_decoder_input_output_embed, adaptive_softmax_cutoff=( @@ -346,7 +349,7 @@ def __init__( self, dictionary, embed_dim=512, hidden_size=512, out_embed_dim=512, num_layers=1, dropout_in=0.1, dropout_out=0.1, encoder_output_units=512, attn_type='bahdanau', attn_dim=256, - pretrained_embed=None, share_input_output_embed=False, + need_attn=False, pretrained_embed=None, share_input_output_embed=False, adaptive_softmax_cutoff=None, ): super().__init__(dictionary) @@ -354,7 +357,7 @@ def __init__( self.dropout_out = dropout_out self.hidden_size = hidden_size self.share_input_output_embed = share_input_output_embed - self.need_attn = True + self.need_attn = need_attn self.adaptive_softmax = None num_embeddings = len(dictionary) @@ -545,6 +548,7 @@ def base_architecture(args): args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 960) args.attention_type = getattr(args, 'attention_type', 'bahdanau') args.attention_dim = getattr(args, 'attention_dim', 320) + args.need_attention = getattr(args, 'need_attention', False) args.encoder_rnn_dropout_in = getattr(args, 'encoder_rnn_dropout_in', args.dropout) args.encoder_rnn_dropout_out = getattr(args, 'encoder_rnn_dropout_out', args.dropout) args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', args.dropout) diff --git a/speech_train.py b/speech_train.py index be0828ee0..4b72bb229 100644 --- a/speech_train.py +++ b/speech_train.py @@ -102,8 +102,8 @@ def main(args): if epoch_itr.epoch % args.validate_interval == 0: valid_losses, valid_wers = validate(args, trainer, task, epoch_itr, valid_subsets) - # only use first validation loss to update the learning rate - lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0]) + # only use first validation wer to update the learning rate + lr = trainer.lr_step(epoch_itr.epoch, valid_wers[0]) # save checkpoint if epoch_itr.epoch % args.save_interval == 0: From 8782e5d581a253022c7333132feeb1893ab2eaf6 Mon Sep 17 00:00:00 2001 From: freewym Date: Fri, 8 Feb 2019 00:05:55 -0500 Subject: 
[PATCH 009/119] environment configurations --- .gitignore | 3 +++ examples/asr_wsj/path.sh | 13 ++++++++++--- speech_tools/Makefile | 18 ++++++++++++++++++ 3 files changed, 31 insertions(+), 3 deletions(-) create mode 100644 speech_tools/Makefile diff --git a/.gitignore b/.gitignore index fbe71542a..4cfe9214b 100644 --- a/.gitignore +++ b/.gitignore @@ -137,3 +137,6 @@ wandb/ # emacs saves *~ + +# speech related +speech_tools/kaldi diff --git a/examples/asr_wsj/path.sh b/examples/asr_wsj/path.sh index 731d834f6..621112944 100644 --- a/examples/asr_wsj/path.sh +++ b/examples/asr_wsj/path.sh @@ -1,7 +1,14 @@ -export PATH=~/anaconda3/bin:$PATH - MAIN_ROOT=$PWD/../.. +KALDI_ROOT=$MAIN_ROOT/speech_tools/kaldi -export PATH=$MAIN_ROOT:$PATH +# BEGIN from kaldi path.sh +[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh +export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/tools/sctk/bin:$PWD:$PATH +[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 +. $KALDI_ROOT/tools/config/common_path.sh export LC_ALL=C +# END + +export PATH=~/anaconda3/bin:$PATH +export PATH=$MAIN_ROOT:$PATH diff --git a/speech_tools/Makefile b/speech_tools/Makefile new file mode 100644 index 000000000..bc98fca9e --- /dev/null +++ b/speech_tools/Makefile @@ -0,0 +1,18 @@ +KALDI = + +.PHONY: all clean + +all: kaldi + +ifneq ($(strip $(KALDI)),) +kaldi: + ln -s $(KALDI) kaldi +else +kaldi: + # git clone https://github.com/kaldi-asr/kaldi.git kaldi_github; cd kaldi_github/tools; $(MAKE) all + # cd kaldi_github/src; ./configure --shared --use-cuda=no; $(MAKE) depend; $(MAKE) all + # ln -nfs kaldi_github kaldi +endif + +clean: + rm -fr kaldi From 9363a71f3924c6ad3e57c01f6bfae1643caabfe1 Mon Sep 17 00:00:00 2001 From: freewym Date: Fri, 8 Feb 2019 00:08:24 -0500 Subject: [PATCH 010/119] code adaptation/changes according to the commits from Jan 29, 2019 to Feb 9, 2019 --- fairseq/data/token_dictionary.py | 11 ++++++--- fairseq/tasks/speech_recognition.py | 24 +++++++++++++++---- speech_recognition.py | 6 ++++- speech_train.py | 36 ++++++++++++++--------------- 4 files changed, 49 insertions(+), 28 deletions(-) mode change 100644 => 100755 speech_train.py diff --git a/fairseq/data/token_dictionary.py b/fairseq/data/token_dictionary.py index 0c44bcc2b..faaabb3a2 100644 --- a/fairseq/data/token_dictionary.py +++ b/fairseq/data/token_dictionary.py @@ -40,9 +40,14 @@ def token_string(i): else: return self[i] - sent = ' '.join(token_string(i) for i in tensor if i != self.eos() and \ - i != self.pad()) - if bpe_symbol is not None: + if bpe_symbol == 'sentencepiece': + sent = ''.join(token_string(i) for i in tensor if i != self.eos() \ + and i != self.pad()) + sent = sent.replace('\u2581', ' ').strip() + else: + sent = ' '.join(token_string(i) for i in tensor if i != self.eos() \ + and i != self.pad()) + if bpe_symbol is not None and bpe_symbol != 'sentencepiece': sent = (sent + ' ').replace(bpe_symbol, '').rstrip() return sent diff --git a/fairseq/tasks/speech_recognition.py b/fairseq/tasks/speech_recognition.py index b4b56d002..65f8fde64 100644 --- a/fairseq/tasks/speech_recognition.py +++ b/fairseq/tasks/speech_recognition.py @@ -32,8 +32,8 @@ class SpeechRecognitionTask(FairseqTask): .. note:: - The speech recognition task is compatible with :mod:`train.py `, - :mod:`generate.py ` and :mod:`interactive.py `. 
+ The speech recognition task is compatible with :mod:`speech-train`, + :mod:`speech-recognition` and :mod:`fairseq-interactive`. The speech recognition task provides the following additional command-line arguments: @@ -81,6 +81,21 @@ def add_args(parser): help='amount to upsample primary dataset') # fmt: off + @classmethod + def load_dictionary(cls, filename, non_lang_syms=None): + """Load the dictionary from the filename + Args: + filename (str): the filename + non_lang_syms (str): non_lang_syms filename + """ + return TokenDictionary.load(filename, f_non_lang_syms=non_lang_syms) + + @classmethod + def build_dictionary(cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8): + """Disable this method + """ + raise NotImplementedError + @staticmethod def load_pretrained_model(path, dict_path, non_lang_syms=None, arg_overrides=None): @@ -88,7 +103,7 @@ def load_pretrained_model(path, dict_path, non_lang_syms=None, args = model['args'] state_dict = model['model'] args = utils.override_model_args(args, arg_overrides) - dict = TokenDictionary.load(dict_path, f_non_lang_syms=non_lang_syms) + dict = cls.load_dictionary(dict_path, non_lang_syms=non_lang_syms) task = SpeechRecognitionTask(args, dict) model = task.build_model(args) @@ -113,8 +128,7 @@ def setup_task(cls, args, **kwargs): # load dictionaries dict_path = os.path.join(os.path.dirname(args.text_files[0]), 'dict.txt') if args.dict is None else args.dict - dict = TokenDictionary.load(dict_path, - f_non_lang_syms=args.non_lang_syms) + dict = cls.load_dictionary(dict_path, non_lang_syms=args.non_lang_syms) print('| dictionary: {} types'.format(len(dict))) return cls(args, dict) diff --git a/speech_recognition.py b/speech_recognition.py index 794945dd7..a87c42bdc 100644 --- a/speech_recognition.py +++ b/speech_recognition.py @@ -196,10 +196,14 @@ def print_options_meaning_changes(args): print('| --print-alignment is set to plot attentions') -if __name__ == '__main__': +def cli_main(): parser = options.get_generation_parser(default_task='speech_recognition') parser.add_argument('--output-dir', metavar='DIR', required=True, help='path to output results') args = options.parse_args_and_arch(parser) print_options_meaning_changes(args) main(args) + + +if __name__ == '__main__': + cli_main() diff --git a/speech_train.py b/speech_train.py old mode 100644 new mode 100755 index 4b72bb229..a1a312e68 --- a/speech_train.py +++ b/speech_train.py @@ -1,5 +1,6 @@ -#!/usr/bin/env python3 -u +#!/usr/bin/env python3 # Copyright (c) 2017-present, Facebook, Inc. +# 2018-present, Yiming Wang # All rights reserved. 
# # This source code is licensed under the license found in the LICENSE file in @@ -24,7 +25,7 @@ from fairseq.utils import import_user_module -def main(args): +def main(args, init_distributed=False): import_user_module(args) if args.max_tokens is None: @@ -41,6 +42,12 @@ def main(args): # Load dataset splits load_dataset_splits(task, ['train', 'valid']) + # Initialize distributed training (after data loading) + if init_distributed: + import socket + args.distributed_rank = distributed_utils.distributed_init(args) + print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank)) + # Build model and criterion model = task.build_model(args) criterion = task.build_criterion(args) @@ -384,13 +391,10 @@ def load_dataset_splits(task, splits): def distributed_main(i, args): - import socket args.device_id = i if args.distributed_rank is None: # torch.multiprocessing.spawn args.distributed_rank = i - args.distributed_rank = distributed_utils.distributed_init(args) - print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank)) - main(args) + main(args, init_distributed=True) def print_options_meaning_changes(args): @@ -400,7 +404,7 @@ def print_options_meaning_changes(args): print('| --max-tokens is the maximum number of input frames in a batch') -if __name__ == '__main__': +def cli_main(): parser = options.get_training_parser(default_task='speech_recognition') args = options.parse_args_and_arch(parser) print_options_meaning_changes(args) @@ -416,18 +420,8 @@ def print_options_meaning_changes(args): port = random.randint(10000, 20000) args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port) args.distributed_rank = None # set based on device id - print( - '''| NOTE: you may get better performance with: - - python -m torch.distributed.launch --nproc_per_node {ngpu} train.py {no_c10d}(...) 
- '''.format( - ngpu=args.distributed_world_size, - no_c10d=( - '--ddp-backend=no_c10d ' if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d' - else '' - ), - ) - ) + if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d': + print('| NOTE: you may get better performance with: --ddp-backend=no_c10d') torch.multiprocessing.spawn( fn=distributed_main, args=(args, ), @@ -436,3 +430,7 @@ def print_options_meaning_changes(args): else: # single GPU training main(args) + + +if __name__ == '__main__': + cli_main() From 0802783238eacdbe8f09ce1a8a9da54b1ca18aee Mon Sep 17 00:00:00 2001 From: freewym Date: Sat, 9 Feb 2019 17:00:11 -0500 Subject: [PATCH 011/119] Add wsj data prep recipe from kaldi and espnet --- .gitignore | 3 - examples/asr_wsj/cmd.sh | 20 ++ examples/asr_wsj/conf/fbank.conf | 2 + examples/asr_wsj/conf/pitch.conf | 1 + examples/asr_wsj/local/find_transcripts.pl | 64 ++++++ examples/asr_wsj/local/flist2scp.pl | 31 +++ examples/asr_wsj/local/ndx2flist.pl | 62 +++++ .../asr_wsj/local/normalize_transcript.pl | 59 +++++ .../asr_wsj/{ => local}/wer_output_filter | 0 examples/asr_wsj/local/wsj_data_prep.sh | 214 ++++++++++++++++++ examples/asr_wsj/local/wsj_format_data.sh | 35 +++ examples/asr_wsj/path.sh | 3 +- examples/asr_wsj/run.sh | 156 ++++++++----- examples/asr_wsj/steps | 1 + examples/asr_wsj/utils | 1 + fairseq/criterions/cross_entropy_with_wer.py | 2 +- .../label_smoothed_cross_entropy_with_wer.py | 197 ++++++++++++++++ fairseq/models/speech_lstm.py | 17 +- fairseq/tasks/speech_recognition.py | 7 +- speech_recognition.py | 5 +- speech_tools/.gitignore | 1 + speech_tools/dump.sh | 84 +++++++ speech_tools/parse_options.sh | 97 -------- speech_train.py | 8 +- 24 files changed, 897 insertions(+), 173 deletions(-) create mode 100644 examples/asr_wsj/cmd.sh create mode 100644 examples/asr_wsj/conf/fbank.conf create mode 100644 examples/asr_wsj/conf/pitch.conf create mode 100755 examples/asr_wsj/local/find_transcripts.pl create mode 100755 examples/asr_wsj/local/flist2scp.pl create mode 100755 examples/asr_wsj/local/ndx2flist.pl create mode 100755 examples/asr_wsj/local/normalize_transcript.pl rename examples/asr_wsj/{ => local}/wer_output_filter (100%) create mode 100755 examples/asr_wsj/local/wsj_data_prep.sh create mode 100755 examples/asr_wsj/local/wsj_format_data.sh create mode 120000 examples/asr_wsj/steps create mode 120000 examples/asr_wsj/utils create mode 100644 fairseq/criterions/label_smoothed_cross_entropy_with_wer.py mode change 100644 => 100755 speech_recognition.py create mode 100644 speech_tools/.gitignore create mode 100755 speech_tools/dump.sh delete mode 100755 speech_tools/parse_options.sh diff --git a/.gitignore b/.gitignore index 4cfe9214b..fbe71542a 100644 --- a/.gitignore +++ b/.gitignore @@ -137,6 +137,3 @@ wandb/ # emacs saves *~ - -# speech related -speech_tools/kaldi diff --git a/examples/asr_wsj/cmd.sh b/examples/asr_wsj/cmd.sh new file mode 100644 index 000000000..008ac4efa --- /dev/null +++ b/examples/asr_wsj/cmd.sh @@ -0,0 +1,20 @@ +# you can change cmd.sh depending on what type of queue you are using. +# If you have no queueing system and want to run on a local machine, you +# can change all instances 'queue.pl' to run.pl (but be careful and run +# commands one by one: most recipes will exhaust the memory on your +# machine). queue.pl works with GridEngine (qsub). slurm.pl works +# with slurm. 
Different queues are configured differently, with different +# queue names and different ways of specifying things like memory; +# to account for these differences you can create and edit the file +# conf/queue.conf to match your queue's configuration. Search for +# conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information, +# or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl. + +#export train_cmd="run.pl --mem 2G" +#export cuda_cmd="run.pl --mem 2G --gpu 1" +#export decode_cmd="run.pl --mem 4G" + +# JHU setup +export train_cmd="queue.pl --mem 2G" +export cuda_cmd="queue.pl --mem 2G --gpu 1 --config conf/gpu.conf" +export decode_cmd="queue.pl --mem 4G" diff --git a/examples/asr_wsj/conf/fbank.conf b/examples/asr_wsj/conf/fbank.conf new file mode 100644 index 000000000..82ac7bd0d --- /dev/null +++ b/examples/asr_wsj/conf/fbank.conf @@ -0,0 +1,2 @@ +--sample-frequency=16000 +--num-mel-bins=80 diff --git a/examples/asr_wsj/conf/pitch.conf b/examples/asr_wsj/conf/pitch.conf new file mode 100644 index 000000000..e959a19d5 --- /dev/null +++ b/examples/asr_wsj/conf/pitch.conf @@ -0,0 +1 @@ +--sample-frequency=16000 diff --git a/examples/asr_wsj/local/find_transcripts.pl b/examples/asr_wsj/local/find_transcripts.pl new file mode 100755 index 000000000..6429411b8 --- /dev/null +++ b/examples/asr_wsj/local/find_transcripts.pl @@ -0,0 +1,64 @@ +#!/usr/bin/env perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + + +# This program takes on its standard input a list of utterance +# id's, one for each line. (e.g. 4k0c030a is a an utterance id). +# It takes as +# Extracts from the dot files the transcripts for a given +# dataset (represented by a file list). +# + +@ARGV == 1 || die "find_transcripts.pl dot_files_flist < utterance_ids > transcripts"; +$dot_flist = shift @ARGV; + +open(L, "<$dot_flist") || die "Opening file list of dot files: $dot_flist\n"; +while(){ + chop; + m:\S+/(\w{6})00.dot: || die "Bad line in dot file list: $_"; + $spk = $1; + $spk2dot{$spk} = $_; +} + + + +while(){ + chop; + $uttid = $_; + $uttid =~ m:(\w{6})\w\w: || die "Bad utterance id $_"; + $spk = $1; + if($spk ne $curspk) { + %utt2trans = { }; # Don't keep all the transcripts in memory... 
+ $curspk = $spk; + $dotfile = $spk2dot{$spk}; + defined $dotfile || die "No dot file for speaker $spk\n"; + open(F, "<$dotfile") || die "Error opening dot file $dotfile\n"; + while() { + $_ =~ m:(.+)\((\w{8})\)\s*$: || die "Bad line $_ in dot file $dotfile (line $.)\n"; + $trans = $1; + $utt = $2; + $utt2trans{$utt} = $trans; + } + } + if(!defined $utt2trans{$uttid}) { + print STDERR "No transcript for utterance $uttid (current dot file is $dotfile)\n"; + } else { + print "$uttid $utt2trans{$uttid}\n"; + } +} + + diff --git a/examples/asr_wsj/local/flist2scp.pl b/examples/asr_wsj/local/flist2scp.pl new file mode 100755 index 000000000..234e4add1 --- /dev/null +++ b/examples/asr_wsj/local/flist2scp.pl @@ -0,0 +1,31 @@ +#!/usr/bin/env perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# takes in a file list with lines like +# /mnt/matylda2/data/WSJ1/13-16.1/wsj1/si_dt_20/4k0/4k0c030a.wv1 +# and outputs an scp in kaldi format with lines like +# 4k0c030a /mnt/matylda2/data/WSJ1/13-16.1/wsj1/si_dt_20/4k0/4k0c030a.wv1 +# (the first thing is the utterance-id, which is the same as the basename of the file. + + +while(<>){ + m:^\S+/(\w+)\.[wW][vV]1$: || die "Bad line $_"; + $id = $1; + $id =~ tr/A-Z/a-z/; # Necessary because of weirdness on disk 13-16.1 (uppercase filenames) + print "$id $_"; +} + diff --git a/examples/asr_wsj/local/ndx2flist.pl b/examples/asr_wsj/local/ndx2flist.pl new file mode 100755 index 000000000..48fc3dec1 --- /dev/null +++ b/examples/asr_wsj/local/ndx2flist.pl @@ -0,0 +1,62 @@ +#!/usr/bin/env perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# This program takes as its standard input an .ndx file from the WSJ corpus that looks +# like this: +#;; File: tr_s_wv1.ndx, updated 04/26/94 +#;; +#;; Index for WSJ0 SI-short Sennheiser training data +#;; Data is read WSJ sentences, Sennheiser mic. +#;; Contains 84 speakers X (~100 utts per speaker MIT/SRI and ~50 utts +#;; per speaker TI) = 7236 utts +#;; +#11_1_1:wsj0/si_tr_s/01i/01ic0201.wv1 +#11_1_1:wsj0/si_tr_s/01i/01ic0202.wv1 +#11_1_1:wsj0/si_tr_s/01i/01ic0203.wv1 + +#and as command-line arguments it takes the names of the WSJ disk locations, e.g.: +#/mnt/matylda2/data/WSJ0/11-1.1 /mnt/matylda2/data/WSJ0/11-10.1 ... etc. +# It outputs a list of absolute pathnames (it does this by replacing e.g. 
11_1_1 with +# /mnt/matylda2/data/WSJ0/11-1.1. +# It also does a slight fix because one of the WSJ disks (WSJ1/13-16.1) was distributed with +# uppercase rather than lower case filenames. + +foreach $fn (@ARGV) { + $fn =~ m:.+/([0-9\.\-]+)/?$: || die "Bad command-line argument $fn\n"; + $disk_id=$1; + $disk_id =~ tr/-\./__/; # replace - and . with - so 11-10.1 becomes 11_10_1 + $fn =~ s:/$::; # Remove final slash, just in case it is present. + $disk2fn{$disk_id} = $fn; +} + +while(){ + if(m/^;/){ next; } # Comment. Ignore it. + else { + m/^([0-9_]+):\s*(\S+)$/ || die "Could not parse line $_"; + $disk=$1; + if(!defined $disk2fn{$disk}) { + die "Disk id $disk not found"; + } + $filename = $2; # as a subdirectory of the distributed disk. + if($disk eq "13_16_1" && `hostname` =~ m/fit.vutbr.cz/) { + # The disk 13-16.1 has been uppercased for some reason, on the + # BUT system. This is a fix specifically for that case. + $filename =~ tr/a-z/A-Z/; # This disk contains all uppercase filenames. Why? + } + print "$disk2fn{$disk}/$filename\n"; + } +} diff --git a/examples/asr_wsj/local/normalize_transcript.pl b/examples/asr_wsj/local/normalize_transcript.pl new file mode 100755 index 000000000..09cee0617 --- /dev/null +++ b/examples/asr_wsj/local/normalize_transcript.pl @@ -0,0 +1,59 @@ +#!/usr/bin/env perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# This takes data from the standard input that's unnormalized transcripts in the format +# 4k2c0308 Of course there isn\'t any guarantee the company will keep its hot hand [misc_noise] +# 4k2c030a [loud_breath] And new hardware such as the set of personal computers I\. B\. M\. introduced last week can lead to unexpected changes in the software business [door_slam] +# and outputs normalized transcripts. +# c.f. /mnt/matylda2/data/WSJ0/11-10.1/wsj0/transcrp/doc/dot_spec.doc + +@ARGV == 1 || die "usage: normalize_transcript.pl noise_word < transcript > transcript2"; +$noise_word = shift @ARGV; + +while() { + $_ =~ m:^(\S+) (.+): || die "bad line $_"; + $utt = $1; + $trans = $2; + print "$utt"; + foreach $w (split (" ",$trans)) { + $w =~ tr:a-z:A-Z:; # Upcase everything to match the CMU dictionary. . + $w =~ s:\\::g; # Remove backslashes. We don't need the quoting. + $w =~ s:^\%PERCENT$:PERCENT:; # Normalization for Nov'93 test transcripts. + $w =~ s:^\.POINT$:POINT:; # Normalization for Nov'93 test transcripts. + if($w =~ m:^\[\<\w+\]$: || # E.g. [\]$: || # E.g. [door_slam>], this means a door slammed in the next word. Delete. + $w =~ m:\[\w+/\]$: || # E.g. [phone_ring/], which indicates the start of this phenomenon. + $w =~ m:\[\/\w+]$: || # E.g. [/phone_ring], which indicates the end of this phenomenon. + $w eq "~" || # This is used to indicate truncation of an utterance. Not a word. + $w eq ".") { # "." is used to indicate a pause. Silence is optional anyway so not much + # point including this in the transcript. 
+ next; # we won't print this word. + } elsif($w =~ m:\[\w+\]:) { # Other noises, e.g. [loud_breath]. + print " $noise_word"; + } elsif($w =~ m:^\<([\w\']+)\>$:) { + # e.g. replace with and. (the <> means verbal deletion of a word).. but it's pronounced. + print " $1"; + } elsif($w eq "--DASH") { + print " -DASH"; # This is a common issue; the CMU dictionary has it as -DASH. +# } elsif($w =~ m:(.+)\-DASH$:) { # E.g. INCORPORATED-DASH... seems the DASH gets combined with previous word +# print " $1 -DASH"; + } else { + print " $w"; + } + } + print "\n"; +} diff --git a/examples/asr_wsj/wer_output_filter b/examples/asr_wsj/local/wer_output_filter similarity index 100% rename from examples/asr_wsj/wer_output_filter rename to examples/asr_wsj/local/wer_output_filter diff --git a/examples/asr_wsj/local/wsj_data_prep.sh b/examples/asr_wsj/local/wsj_data_prep.sh new file mode 100755 index 000000000..04f2f6390 --- /dev/null +++ b/examples/asr_wsj/local/wsj_data_prep.sh @@ -0,0 +1,214 @@ +#!/bin/bash + +# Copyright 2009-2012 Microsoft Corporation Johns Hopkins University (Author: Daniel Povey) +# Apache 2.0. + + +if [ $# -le 3 ]; then + echo "Arguments should be a list of WSJ directories, see ../run.sh for example." + exit 1; +fi + + +dir=`pwd`/data/local/data +lmdir=`pwd`/data/local/nist_lm +mkdir -p $dir $lmdir +local=`pwd`/local +utils=`pwd`/utils + +. ./path.sh # Needed for KALDI_ROOT +sph2pipe=$KALDI_ROOT/tools/sph2pipe_v2.5/sph2pipe +if [ ! -x $sph2pipe ]; then + echo "Could not find (or execute) the sph2pipe program at $sph2pipe"; + exit 1; +fi + +if [ -z $IRSTLM ] ; then + export IRSTLM=$KALDI_ROOT/tools/irstlm/ +fi +export PATH=${PATH}:$IRSTLM/bin +if ! command -v prune-lm >/dev/null 2>&1 ; then + echo "$0: Error: the IRSTLM is not available or compiled" >&2 + echo "$0: Error: We used to install it by default, but." >&2 + echo "$0: Error: this is no longer the case." >&2 + echo "$0: Error: To install it, go to $KALDI_ROOT/tools" >&2 + echo "$0: Error: and run extras/install_irstlm.sh" >&2 + exit 1 +fi + +cd $dir +# Make directory of links to the WSJ disks such as 11-13.1. This relies on the command +# line arguments being absolute pathnames. +rm -r links/ 2>/dev/null +mkdir links/ +ln -s $* links + +# Do some basic checks that we have what we expected. +if [ ! -d links/11-13.1 -o ! -d links/13-34.1 -o ! -d links/11-2.1 ]; then + echo "wsj_data_prep.sh: Spot check of command line arguments failed" + echo "Command line arguments must be absolute pathnames to WSJ directories" + echo "with names like 11-13.1." + echo "Note: if you have old-style WSJ distribution," + echo "local/cstr_wsj_data_prep.sh may work instead, see run.sh for example." + exit 1; +fi + +# This version for SI-84 + +cat links/11-13.1/wsj0/doc/indices/train/tr_s_wv1.ndx | \ + $local/ndx2flist.pl $* | sort | \ + grep -v -i 11-2.1/wsj0/si_tr_s/401 > train_si84.flist + +nl=`cat train_si84.flist | wc -l` +[ "$nl" -eq 7138 ] || echo "Warning: expected 7138 lines in train_si84.flist, got $nl" + +# This version for SI-284 +cat links/13-34.1/wsj1/doc/indices/si_tr_s.ndx \ + links/11-13.1/wsj0/doc/indices/train/tr_s_wv1.ndx | \ + $local/ndx2flist.pl $* | sort | \ + grep -v -i 11-2.1/wsj0/si_tr_s/401 > train_si284.flist + +nl=`cat train_si284.flist | wc -l` +[ "$nl" -eq 37416 ] || echo "Warning: expected 37416 lines in train_si284.flist, got $nl" + +# Now for the test sets. +# links/13-34.1/wsj1/doc/indices/readme.doc +# describes all the different test sets. 
+# Note: each test-set seems to come in multiple versions depending +# on different vocabulary sizes, verbalized vs. non-verbalized +# pronunciations, etc. We use the largest vocab and non-verbalized +# pronunciations. +# The most normal one seems to be the "baseline 60k test set", which +# is h1_p0. + +# Nov'92 (333 utts) +# These index files have a slightly different format; +# have to add .wv1 +cat links/11-13.1/wsj0/doc/indices/test/nvp/si_et_20.ndx | \ + $local/ndx2flist.pl $* | awk '{printf("%s.wv1\n", $1)}' | \ + sort > test_eval92.flist + +# Nov'92 (330 utts, 5k vocab) +cat links/11-13.1/wsj0/doc/indices/test/nvp/si_et_05.ndx | \ + $local/ndx2flist.pl $* | awk '{printf("%s.wv1\n", $1)}' | \ + sort > test_eval92_5k.flist + +# Nov'93: (213 utts) +# Have to replace a wrong disk-id. +cat links/13-32.1/wsj1/doc/indices/wsj1/eval/h1_p0.ndx | \ + sed s/13_32_1/13_33_1/ | \ + $local/ndx2flist.pl $* | sort > test_eval93.flist + +# Nov'93: (213 utts, 5k) +cat links/13-32.1/wsj1/doc/indices/wsj1/eval/h2_p0.ndx | \ + sed s/13_32_1/13_33_1/ | \ + $local/ndx2flist.pl $* | sort > test_eval93_5k.flist + +# Dev-set for Nov'93 (503 utts) +cat links/13-34.1/wsj1/doc/indices/h1_p0.ndx | \ + $local/ndx2flist.pl $* | sort > test_dev93.flist + +# Dev-set for Nov'93 (513 utts, 5k vocab) +cat links/13-34.1/wsj1/doc/indices/h2_p0.ndx | \ + $local/ndx2flist.pl $* | sort > test_dev93_5k.flist + + +# Dev-set Hub 1,2 (503, 913 utterances) + +# Note: the ???'s below match WSJ and SI_DT, or wsj and si_dt. +# Sometimes this gets copied from the CD's with upcasing, don't know +# why (could be older versions of the disks). +find `readlink links/13-16.1`/???1/??_??_20 -print | grep -i ".wv1" | sort > dev_dt_20.flist +find `readlink links/13-16.1`/???1/??_??_05 -print | grep -i ".wv1" | sort > dev_dt_05.flist + + +# Finding the transcript files: +for x in $*; do find -L $x -iname '*.dot'; done > dot_files.flist + +# Convert the transcripts into our format (no normalization yet) +for x in train_si84 train_si284 test_eval92 test_eval93 test_dev93 test_eval92_5k test_eval93_5k test_dev93_5k dev_dt_05 dev_dt_20; do + $local/flist2scp.pl $x.flist | sort > ${x}_sph.scp + cat ${x}_sph.scp | awk '{print $1}' | $local/find_transcripts.pl dot_files.flist > $x.trans1 +done + +# Do some basic normalization steps. At this point we don't remove OOVs-- +# that will be done inside the training scripts, as we'd like to make the +# data-preparation stage independent of the specific lexicon used. +noiseword=""; +for x in train_si84 train_si284 test_eval92 test_eval93 test_dev93 test_eval92_5k test_eval93_5k test_dev93_5k dev_dt_05 dev_dt_20; do + cat $x.trans1 | $local/normalize_transcript.pl $noiseword | sort > $x.txt || exit 1; +done + +# Create scp's with wav's. (the wv1 in the distribution is not really wav, it is sph.) +for x in train_si84 train_si284 test_eval92 test_eval93 test_dev93 test_eval92_5k test_eval93_5k test_dev93_5k dev_dt_05 dev_dt_20; do + awk '{printf("%s '$sph2pipe' -f wav %s |\n", $1, $2);}' < ${x}_sph.scp > ${x}_wav.scp +done + +# Make the utt2spk and spk2utt files. 
+for x in train_si84 train_si284 test_eval92 test_eval93 test_dev93 test_eval92_5k test_eval93_5k test_dev93_5k dev_dt_05 dev_dt_20; do + cat ${x}_sph.scp | awk '{print $1}' | perl -ane 'chop; m:^...:; print "$_ $&\n";' > $x.utt2spk + cat $x.utt2spk | $utils/utt2spk_to_spk2utt.pl > $x.spk2utt || exit 1; +done + + +#in case we want to limit lm's on most frequent words, copy lm training word frequency list +cp links/13-32.1/wsj1/doc/lng_modl/vocab/wfl_64.lst $lmdir +chmod u+w $lmdir/*.lst # had weird permissions on source. + +# The 20K vocab, open-vocabulary language model (i.e. the one with UNK), without +# verbalized pronunciations. This is the most common test setup, I understand. + +cp links/13-32.1/wsj1/doc/lng_modl/base_lm/bcb20onp.z $lmdir/lm_bg.arpa.gz || exit 1; +chmod u+w $lmdir/lm_bg.arpa.gz + +# trigram would be: +cat links/13-32.1/wsj1/doc/lng_modl/base_lm/tcb20onp.z | \ + perl -e 'while(<>){ if(m/^\\data\\/){ print; last; } } while(<>){ print; }' | \ + gzip -c -f > $lmdir/lm_tg.arpa.gz || exit 1; + +prune-lm --threshold=1e-7 $lmdir/lm_tg.arpa.gz $lmdir/lm_tgpr.arpa || exit 1; +gzip -f $lmdir/lm_tgpr.arpa || exit 1; + +# repeat for 5k language models +cp links/13-32.1/wsj1/doc/lng_modl/base_lm/bcb05onp.z $lmdir/lm_bg_5k.arpa.gz || exit 1; +chmod u+w $lmdir/lm_bg_5k.arpa.gz + +# trigram would be: !only closed vocabulary here! +cp links/13-32.1/wsj1/doc/lng_modl/base_lm/tcb05cnp.z $lmdir/lm_tg_5k.arpa.gz || exit 1; +chmod u+w $lmdir/lm_tg_5k.arpa.gz +gunzip $lmdir/lm_tg_5k.arpa.gz +tail -n 4328839 $lmdir/lm_tg_5k.arpa | gzip -c -f > $lmdir/lm_tg_5k.arpa.gz +rm $lmdir/lm_tg_5k.arpa + +prune-lm --threshold=1e-7 $lmdir/lm_tg_5k.arpa.gz $lmdir/lm_tgpr_5k.arpa || exit 1; +gzip -f $lmdir/lm_tgpr_5k.arpa || exit 1; + + +if [ ! -f wsj0-train-spkrinfo.txt ] || [ `cat wsj0-train-spkrinfo.txt | wc -l` -ne 134 ]; then + rm wsj0-train-spkrinfo.txt + ! wget https://catalog.ldc.upenn.edu/docs/LDC93S6A/wsj0-train-spkrinfo.txt && \ + echo "Getting wsj0-train-spkrinfo.txt from backup location" && \ + wget --no-check-certificate https://sourceforge.net/projects/kaldi/files/wsj0-train-spkrinfo.txt +fi + +if [ ! -f wsj0-train-spkrinfo.txt ]; then + echo "Could not get the spkrinfo.txt file from LDC website (moved)?" + echo "This is possibly omitted from the training disks; couldn't find it." + echo "Everything else may have worked; we just may be missing gender info" + echo "which is only needed for VTLN-related diagnostics anyway." + exit 1 +fi +# Note: wsj0-train-spkrinfo.txt doesn't seem to be on the disks but the +# LDC put it on the web. Perhaps it was accidentally omitted from the +# disks. + +cat links/11-13.1/wsj0/doc/spkrinfo.txt \ + links/13-32.1/wsj1/doc/evl_spok/spkrinfo.txt \ + links/13-34.1/wsj1/doc/dev_spok/spkrinfo.txt \ + links/13-34.1/wsj1/doc/train/spkrinfo.txt \ + ./wsj0-train-spkrinfo.txt | \ + perl -ane 'tr/A-Z/a-z/; m/^;/ || print;' | \ + awk '{print $1, $2}' | grep -v -- -- | sort | uniq > spk2gender + + +echo "Data preparation succeeded" diff --git a/examples/asr_wsj/local/wsj_format_data.sh b/examples/asr_wsj/local/wsj_format_data.sh new file mode 100755 index 000000000..d567fd1bd --- /dev/null +++ b/examples/asr_wsj/local/wsj_format_data.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +# Copyright 2012 Microsoft Corporation Johns Hopkins University (Author: Daniel Povey) +# 2015 Guoguo Chen +# Apache 2.0 + +# This script takes data prepared in a corpus-dependent way +# in data/local/, and converts it into the "canonical" form, +# in various subdirectories of data/, e.g. 
data/lang, data/lang_test_ug, +# data/train_si284, data/train_si84, etc. + +# Don't bother doing train_si84 separately (although we have the file lists +# in data/local/) because it's just the first 7138 utterances in train_si284. +# We'll create train_si84 after doing the feature extraction. + +lang_suffix= + +echo "$0 $@" # Print the command line for logging +. utils/parse_options.sh || exit 1; + +. ./path.sh || exit 1; + +echo "Preparing train and test data" +srcdir=data/local/data + +for x in train_si284 test_eval92 test_eval93 test_dev93 test_eval92_5k test_eval93_5k test_dev93_5k dev_dt_05 dev_dt_20; do + mkdir -p data/$x + cp $srcdir/${x}_wav.scp data/$x/wav.scp || exit 1; + cp $srcdir/$x.txt data/$x/text || exit 1; + cp $srcdir/$x.spk2utt data/$x/spk2utt || exit 1; + cp $srcdir/$x.utt2spk data/$x/utt2spk || exit 1; + utils/filter_scp.pl data/$x/spk2utt $srcdir/spk2gender > data/$x/spk2gender || exit 1; +done + +echo "Succeeded in formatting data." diff --git a/examples/asr_wsj/path.sh b/examples/asr_wsj/path.sh index 621112944..19a6eb6eb 100644 --- a/examples/asr_wsj/path.sh +++ b/examples/asr_wsj/path.sh @@ -10,5 +10,6 @@ export LC_ALL=C # END export PATH=~/anaconda3/bin:$PATH -export PATH=$MAIN_ROOT:$PATH +export PATH=$MAIN_ROOT:$MAIN_ROOT/speech_tools:$PATH +export PYTHONUNBUFFERED=1 diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index adc7b3457..37cdade9b 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -11,48 +11,86 @@ set -e -o pipefail stage=0 free_gpu= -data_dir=data-bin/wsj -exp_dir=exp/wsj/lstm +affix= train_set=train_si284 valid_set=test_dev93 test_set=test_eval92 checkpoint=checkpoint_best.pt validate_on_train=false +dumpdir=data/dump # directory to dump full features +# feature configuration +do_delta=false -if [ -f ./path.sh ]; then - . ./path.sh -else - . ./examples/asr_wsj/path.sh -fi -if [ -f ../../speech_tools/parse_options.sh ]; then - . ../../speech_tools/parse_options.sh -else - . ./speech_tools/parse_options.sh +# data +wsj0= +wsj1= +if [[ $(hostname -f) == *.clsp.jhu.edu ]]; then + wsj0=/export/corpora5/LDC/LDC93S6B + wsj1=/export/corpora5/LDC/LDC94S13B fi -valid_subset=valid -if $validate_on_train; then - valid_subset="$valid_subset train" +. ./path.sh +. ./cmd.sh +. ./utils/parse_options.sh + +dir=exp/lstm${affix:+_$affix} + +if [ ${stage} -le 0 ]; then + ### Task dependent. You have to make data the following preparation part by yourself. + ### But you can utilize Kaldi recipes in most cases + echo "Stage 0: Data Preparation" + local/wsj_data_prep.sh ${wsj0}/??-{?,??}.? ${wsj1}/??-{?,??}.? + local/wsj_format_data.sh fi -dict=$data_dir/lang/${train_set}_units.txt -nlsyms=$data_dir/lang/non_lang_syms.txt -train_text=$data_dir/$train_set/text +train_feat_dir=${dumpdir}/${train_set}/delta${do_delta}; mkdir -p ${train_feat_dir} +valid_feat_dir=${dumpdir}/${valid_set}/delta${do_delta}; mkdir -p ${valid_feat_dir} +test_feat_dir=${dumpdir}/${test_set}/delta${do_delta}; mkdir -p ${test_feat_dir} if [ ${stage} -le 1 ]; then - echo "Stage 1: Dictionary Preparation and Text Tokenization" - mkdir -p $data_dir/lang + ### Task dependent. You have to design training and dev sets by yourself. 
+ ### But you can utilize Kaldi recipes in most cases + echo "Stage 1: Feature Generation" + fbankdir=fbank + # Generate the fbank features; by default 80-dimensional fbanks with pitch on each frame + for x in $train_set $valid_set $test_set; do + steps/make_fbank_pitch.sh --cmd "$train_cmd" --nj 10 --write_utt2num_frames true \ + data/${x} exp/make_fbank/${x} ${fbankdir} + done + + # compute global CMVN + compute-cmvn-stats scp:data/${train_set}/feats.scp data/${train_set}/cmvn.ark + + # dump features for training + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d ${train_feat_dir}/storage ]; then + utils/create_split_dir.pl \ + /export/b{10,11,12,13}/${USER}/fairseq-data/egs/asr_wsj/dump/${train_set}/delta${do_delta}/storage \ + ${train_feat_dir}/storage + fi + dump.sh --cmd "$train_cmd" --nj 32 --do_delta $do_delta \ + data/${train_set}/feats.scp data/${train_set}/cmvn.ark exp/dump_feats/train ${train_feat_dir} + dump.sh --cmd "$train_cmd" --nj 4 --do_delta $do_delta \ + data/${valid_set}/feats.scp data/${train_set}/cmvn.ark exp/dump_feats/valid ${valid_feat_dir} + dump.sh --cmd "$train_cmd" --nj 4 --do_delta $do_delta \ + data/${test_set}/feats.scp data/${train_set}/cmvn.ark exp/dump_feats/test ${test_feat_dir} +fi + +dict=data/lang/${train_set}_units.txt +nlsyms=data/lang/non_lang_syms.txt +train_text=data/$train_set/text +if [ ${stage} -le 2 ]; then + echo "Stage 2: Dictionary Preparation and Text Tokenization" + mkdir -p data/lang - echo "Making a non-linguistic symbol list..." + echo "$0: making a non-linguistic symbol list..." cut -f 2- $train_text | tr " " "\n" | sort | uniq | grep "<" > $nlsyms cat $nlsyms - echo "Making a dictionary and tokenizing text for train/valid/test set..." + echo "$0: making a dictionary and tokenizing text for train/valid/test set..." 
for dataset in $train_set $valid_set $test_set; do - text=$data_dir/$dataset/text - token_text=$data_dir/$dataset/token_text - python3 speech_tools/text2token.py --skip-ncols 1 --space "" \ - --non-lang-syms $nlsyms $text > $token_text + text=data/$dataset/text + token_text=data/$dataset/token_text + text2token.py --skip-ncols 1 --space "" --non-lang-syms $nlsyms $text > $token_text if [ "$dataset" == "$train_set" ]; then cut -f 2- -d" " $token_text | tr " " "\n" | grep -v -e '^\s*$' | sort | \ uniq -c | awk '{print $2,$1}' > $dict @@ -61,29 +99,34 @@ if [ ${stage} -le 1 ]; then done fi -train_feat=$data_dir/dump/$train_set/deltafalse/feats.scp -train_token_text=$data_dir/$train_set/token_text -valid_feat=$data_dir/dump/$valid_set/deltafalse/feats.scp -valid_token_text=$data_dir/$valid_set/token_text -if [ ${stage} -le 2 ]; then - echo "Stage 2: Model Training" - mkdir -p $exp_dir/logs - log_file=$exp_dir/logs/train.log - [ -f $exp_dir/checkpoint_last.pt ] && log_file="-a $log_file" +train_feat=$train_feat_dir/feats.scp +train_token_text=data/$train_set/token_text +valid_feat=$valid_feat_dir/feats.scp +valid_token_text=data/$valid_set/token_text +if [ ${stage} -le 3 ]; then + echo "Stage 3: Model Training" + valid_subset=valid + if $validate_on_train; then + valid_subset="$valid_subset train" + fi + mkdir -p $dir/logs + log_file=$dir/logs/train.log + [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" opts="" - [ -f examples/asr_wsj/wer_output_filter ] && \ - opts="$opts --wer-output-filter examples/asr_wsj/wer_output_filter" - [ -z "$free_gpu" ] && free_gpu=$(free-gpu) - CUDA_VISIBLE_DEVICES=$free_gpu python3 -u speech_train.py --seed 1 \ - --log-interval 500 --log-format "simple" --print-training-sample-interval 500 \ + [ -f local/wer_output_filter ] && \ + opts="$opts --wer-output-filter local/wer_output_filter" + [ -z "$free_gpu" ] && [[ $(hostname -f) == *.clsp.jhu.edu ]] && free_gpu=$(free-gpu) + [ -z "$free_gpu" ] && echo "$0: please specify --free-gpu" && exit 1; + CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --seed 1 \ + --log-interval 500 --log-format "simple" --print-training-sample-interval 1000 \ --num-workers 0 --max-tokens 24000 --max-sentences 32 \ --valid-subset $valid_subset --max-sentences-valid 64 \ --distributed-world-size 1 --distributed-rank 0 --distributed-port -1 \ --max-epoch 20 --optimizer "adam" --lr 0.001 --weight-decay 0.0 \ - --lr-scheduler "reduce_lr_on_plateau" --lr-shrink 0.1 --min-lr "1e-15" \ - --save-dir $exp_dir --save-interval-updates 200 --keep-interval-updates 10 \ - --keep-last-epochs 5 --validate-interval 1 \ - --arch "speech_conv_lstm_wsj" --criterion "cross_entropy_with_wer" \ + --lr-scheduler "reduce_lr_on_plateau" --lr-shrink 0.5 --min-lr "1e-8" \ + --save-dir $dir --restore-file "checkpoint_last.pt" --save-interval-updates 200 \ + --keep-interval-updates 5 --keep-last-epochs 5 --validate-interval 1 \ + --arch "speech_conv_lstm_wsj" --criterion "label_smoothed_cross_entropy_with_wer" --label-smoothing 0.05 \ --train-feat-files $train_feat --train-text-files $train_token_text \ --valid-feat-files $valid_feat --valid-text-files $valid_token_text \ --dict $dict --non-lang-syms $nlsyms \ @@ -91,22 +134,27 @@ if [ ${stage} -le 2 ]; then exit 0 fi -if [ ${stage} -le 3 ]; then - echo "Stage 3: Decoding" +if [ ${stage} -le 4 ]; then + echo "Stage 4: Decoding" opts="" - [ -f examples/asr_wsj/wer_output_filter ] && \ - opts="$opts --wer-output-filter examples/asr_wsj/wer_output_filter" - [ -z "$free_gpu" ] && free_gpu=$(free-gpu) + [ -f 
local/wer_output_filter ] && \ + opts="$opts --wer-output-filter local/wer_output_filter" for dataset in $valid_set $test_set; do - feat=$data_dir/dump/$dataset/deltafalse/feats.scp - text=$data_dir/$dataset/token_text - CUDA_VISIBLE_DEVICES=$free_gpu python3 -u speech_recognition.py \ + if [ "$dataset" == "$valid_set" ]; then + feat=$valid_feat_dir/feats.scp + elif [ "$dataset" == "$test_set" ]; then + feat=$test_feat_dir/feats.scp + fi + text=data/$dataset/token_text + [ -z "$free_gpu" ] && [[ $(hostname -f) == *.clsp.jhu.edu ]] && free_gpu=$(free-gpu) + [ -z "$free_gpu" ] && echo "$0: please specify --free-gpu" && exit 1; + CUDA_VISIBLE_DEVICES=$free_gpu speech_recognition.py \ --max-tokens 45000 --max-sentences 32 --num-shards 1 --shard-id 0 \ --test-feat-files $feat --test-text-files $text \ --dict $dict --non-lang-syms $nlsyms \ --max-source-positions 9999 --max-target-positions 999 \ - --path $exp_dir/$checkpoint --beam 10 --max-len-a 0.5 --max-len-b 0 \ - --lenpen 1.0 --output-dir $exp_dir/decode_$dataset --print-alignment $opts \ - 2>&1 | tee $exp_dir/logs/decode_$dataset.log + --path $dir/$checkpoint --beam 15 --max-len-a 0.5 --max-len-b 0 \ + --lenpen 1.0 --output-dir $dir/decode_$dataset --print-alignment $opts \ + 2>&1 | tee $dir/logs/decode_$dataset.log done fi diff --git a/examples/asr_wsj/steps b/examples/asr_wsj/steps new file mode 120000 index 000000000..ec9b528ac --- /dev/null +++ b/examples/asr_wsj/steps @@ -0,0 +1 @@ +../../speech_tools/kaldi/egs/wsj/s5/steps \ No newline at end of file diff --git a/examples/asr_wsj/utils b/examples/asr_wsj/utils new file mode 120000 index 000000000..ea44d93b9 --- /dev/null +++ b/examples/asr_wsj/utils @@ -0,0 +1 @@ +../../speech_tools/kaldi/egs/wsj/s5/utils \ No newline at end of file diff --git a/fairseq/criterions/cross_entropy_with_wer.py b/fairseq/criterions/cross_entropy_with_wer.py index 30f155e56..ee595a5a4 100644 --- a/fairseq/criterions/cross_entropy_with_wer.py +++ b/fairseq/criterions/cross_entropy_with_wer.py @@ -17,7 +17,7 @@ from speech_tools.utils import Tokenizer from . import FairseqCriterion, register_criterion -from .cross_entropy import CrossEntropyCriterion +from .cross_entropy import CrossEntropyCriterion @register_criterion('cross_entropy_with_wer') diff --git a/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py b/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py new file mode 100644 index 000000000..ce53d42fc --- /dev/null +++ b/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py @@ -0,0 +1,197 @@ +# Copyright (c) 2019-present, Yiming Wang +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import math +import numpy as np +import torch + +from fairseq import utils, wer +from fairseq.data import data_utils +from fairseq.models import FairseqIncrementalDecoder + +from speech_tools.utils import Tokenizer + +from . 
import FairseqCriterion, register_criterion +from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion + + +@register_criterion('label_smoothed_cross_entropy_with_wer') +class LabelSmoothedCrossEntropyWithWERCriterion(LabelSmoothedCrossEntropyCriterion): + + def __init__(self, args, task): + super().__init__(args, task) + + dict = task.dict if hasattr(task, 'dict') else getattr(task, 'tgt_dict') + self.scorer = wer.Scorer(dict, + wer_output_filter=task.args.wer_output_filter) + self.train_tgt_dataset = task.dataset(args.train_subset).tgt + self.valid_tgt_dataset = None + self.num_updates = -1 + + @staticmethod + def add_args(parser): + """Add criterion-specific arguments to the parser.""" + # fmt: off + LabelSmoothedCrossEntropyCriterion.add_args(parser) + parser.add_argument('--print-training-sample-interval', type=int, + metavar='N', dest='print_interval', default=500, + help='print a training sample (reference + ' + 'prediction) every this number of updates') + # fmt: on + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample; periodically print out + randomly sampled predictions if model is in training mode, otherwise + aggregate word error stats for validation. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + dict = self.scorer.dict + if model.training: + net_output = model(**sample['net_input']) + lprobs = model.get_normalized_probs(net_output, log_probs=True) + target = model.get_targets(sample, net_output) + else: + assert isinstance(model.decoder, FairseqIncrementalDecoder) + incremental_states = {} + encoder_input = { + k: v for k, v in sample['net_input'].items() + if k != 'prev_output_tokens' + } + encoder_out = model.encoder(**encoder_input) + target = sample['target'] + # make the maximum decoding length equal to at least the length of + # target, and the length of encoder_out if possible + # and at least the length of target + maxlen = max(encoder_out['encoder_out'][0].size(1), target.size(1)) + tokens = target.new_full([target.size(0), maxlen + 2], dict.pad()) + tokens[:, 0] = dict.eos() + lprobs = [] + attn = [] if model.decoder.need_attn else None + dummy_log_probs = encoder_out['encoder_out'][0].new_full( + [target.size(0), len(dict)], -np.log(len(dict))) + for step in range(maxlen + 1): # one extra step for EOS marker + is_eos = tokens[:, step].eq(dict.eos()) + # if all predictions are finished (i.e., ended with eos), + # pad lprobs to target length with dummy log probs, + # truncate tokens up to this step and break + if step > 0 and is_eos.sum() == is_eos.size(0): + for _ in range(step, target.size(1)): + lprobs.append(dummy_log_probs) + tokens = tokens[:, :step + 1] + break + log_probs, attn_scores = self._decode(tokens[:, :step + 1], + model, encoder_out, incremental_states) + #log_probs[:, dict.pad()] = -math.inf # never select pad + tokens[:, step + 1] = log_probs.argmax(-1) + if step > 0: # deal with finished predictions + # make log_probs uniform if the previous output token is EOS + # and add consecutive EOS to the end of prediction + log_probs[is_eos, :] = -np.log(log_probs.size(1)) + tokens[is_eos, step + 1] = dict.eos() + if step < target.size(1): + lprobs.append(log_probs) + if model.decoder.need_attn: + attn.append(attn_scores) + # bsz x min(tgtlen, maxlen + 1) x vocab_size + lprobs = torch.stack(lprobs, dim=1) + if model.decoder.need_attn: + # bsz x (maxlen + 1) x 
(length of encoder_out) + attn = torch.stack(attn, dim=1) + # word error stats code starts + if not model.training or (self.num_updates // self.args.print_interval > + (self.num_updates - 1) // self.args.print_interval): + pred = lprobs.argmax(-1).cpu() if model.training else \ + tokens[:, 1:].data.cpu() # bsz x len + + if not model.training: # validation step, compute WER stats with scorer + assert pred.size(0) == target.size(0) + self.scorer.reset() + for i in range(target.size(0)): + utt_id = sample['utt_id'][i] + id = sample['id'].data[i] + #ref_tokens = dict.string(target.data[i]) + ref_tokens = self.valid_tgt_dataset.get_original_tokens(id) + pred_tokens = dict.string(pred.data[i]) + self.scorer.add_evaluation(utt_id, ref_tokens, pred_tokens) + else: # print a randomly sampled result every print_interval updates + assert pred.size() == target.size() + with data_utils.numpy_seed(self.num_updates): + i = np.random.randint(0, len(sample['id'])) + id = sample['id'].data[i] + length = utils.strip_pad(target.data[i], self.padding_idx).size(0) + #ref_one = Tokenizer.tokens_to_sentence(dict.string(target.data[i]), dict) + ref_one = self.train_tgt_dataset.get_original_text(id, dict) + pred_one = Tokenizer.tokens_to_sentence( + dict.string(pred.data[i][:length]), dict) + print('| sample REF: ' + ref_one) + print('| sample PRD: ' + pred_one) + # word error stats code ends + lprobs = lprobs.view(-1, lprobs.size(-1)) + target = target.view(-1, 1) + non_pad_mask = target.ne(self.padding_idx) + nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask] + smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask] + if reduce: + nll_loss = nll_loss.sum() + smooth_loss = smooth_loss.sum() + eps_i = self.eps / lprobs.size(-1) + loss = (1. - self.eps) * nll_loss + eps_i * smooth_loss + sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] + logging_output = { + 'loss': utils.item(loss.data) if reduce else loss.data, + 'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data, + 'ntokens': sample['ntokens'], + 'nsentences': sample['target'].size(0), + 'sample_size': sample_size, + } + if not model.training: # do not compute word error in training mode + logging_output['word_error'] = self.scorer.tot_word_error() + logging_output['word_count'] = self.scorer.tot_word_count() + logging_output['char_error'] = self.scorer.tot_char_error() + logging_output['char_count'] = self.scorer.tot_char_count() + return loss, sample_size, logging_output + + @staticmethod + def aggregate_logging_outputs(logging_outputs): + """Aggregate logging outputs from data parallel training.""" + agg_output = LabelSmoothedCrossEntropyCriterion.aggregate_logging_outputs(logging_outputs) + word_error = sum(log.get('word_error', 0) for log in logging_outputs) + word_count = sum(log.get('word_count', 0) for log in logging_outputs) + char_error = sum(log.get('char_error', 0) for log in logging_outputs) + char_count = sum(log.get('char_count', 0) for log in logging_outputs) + if word_count > 0: # model.training == False + agg_output['word_error'] = word_error + agg_output['word_count'] = word_count + if char_count > 0: # model.training == False + agg_output['char_error'] = char_error + agg_output['char_count'] = char_count + return agg_output + + def _decode(self, tokens, model, encoder_out, incremental_states): + with torch.no_grad(): + decoder_out = list(model.decoder(tokens, encoder_out, + incremental_state=incremental_states)) + decoder_out[0] = decoder_out[0][:, -1, :] + attn = 
decoder_out[1] + if type(attn) is dict: + attn = attn['attn'] + if attn is not None: + if type(attn) is dict: + attn = attn['attn'] + attn = attn[:, -1, :] + probs = model.get_normalized_probs(decoder_out, log_probs=True) + return probs, attn + + def set_valid_tgt_dataset(self, dataset): + self.valid_tgt_dataset = dataset + + def set_num_updates(self, num_updates): + self.num_updates = num_updates diff --git a/fairseq/models/speech_lstm.py b/fairseq/models/speech_lstm.py index c8cc7b80b..03edc59c6 100644 --- a/fairseq/models/speech_lstm.py +++ b/fairseq/models/speech_lstm.py @@ -57,17 +57,16 @@ def add_args(parser): parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N', help='decoder output embedding dimension') parser.add_argument('--attention-type', type=str, metavar='STR', - choices=['bahdanau','luong'], default='bahdanau', + choices=['bahdanau','luong'], help='attention type') parser.add_argument('--attention-dim', type=int, metavar='N', help='attention dimension') - parser.add_argument('--need-attention', default=False, action='store_true', + parser.add_argument('--need-attention', action='store_true', help='need to return attention tensor for the caller') parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', help='comma separated list of adaptive softmax cutoff points. ' 'Must be used with adaptive_loss criterion') - parser.add_argument('--share-decoder-input-output-embed', default=False, - action='store_true', + parser.add_argument('--share-decoder-input-output-embed', action='store_true', help='share decoder input and output embeddings') # Granular dropout settings (if not specified these default to --dropout) @@ -137,13 +136,13 @@ def eval_str_nested_list_or_tuple(x, type=int): args.encoder_conv_kernel_sizes, type=int) strides = eval_str_nested_list_or_tuple(args.encoder_conv_strides, type=int) - in_channels = 1 # hard-coded for now + print('| input feature dimension: {}, channels: {}'.format(task.feat_dim, + task.feat_in_channels)) + assert task.feat_dim % task.feat_in_channels == 0 conv_layers = ConvBNReLU(out_channels, kernel_sizes, strides, - in_channels=in_channels) if not out_channels is None else None + in_channels=task.feat_in_channels) if not out_channels is None else None - print('| input feature dimension: {}'.format(task.feat_dim)) - assert task.feat_dim % in_channels == 0 - rnn_encoder_input_size = task.feat_dim // in_channels + rnn_encoder_input_size = task.feat_dim // task.feat_in_channels if conv_layers is not None: for stride in strides: if isinstance(stride, (list, tuple)): diff --git a/fairseq/tasks/speech_recognition.py b/fairseq/tasks/speech_recognition.py index 65f8fde64..cd4c34e83 100644 --- a/fairseq/tasks/speech_recognition.py +++ b/fairseq/tasks/speech_recognition.py @@ -79,6 +79,8 @@ def add_args(parser): help='max number of tokens in the target sequence') parser.add_argument('--upsample-primary', default=1, type=int, help='amount to upsample primary dataset') + parser.add_argument('--feat-in-channels', default=1, type=int, metavar='N', + help='feature input channels') # fmt: off @classmethod @@ -114,6 +116,7 @@ def load_pretrained_model(path, dict_path, non_lang_syms=None, def __init__(self, args, dict): super().__init__(args) self.dict = dict + self.feat_in_channels = args.feat_in_channels @classmethod def setup_task(cls, args, **kwargs): @@ -201,10 +204,6 @@ def max_positions(self): """Return the max sentence length allowed by the task.""" return (self.args.max_source_positions, self.args.max_target_positions) - def 
source_dictionary(self): - """Return the source :class:`~fairseq.data.Dictionary`.""" - return None - @property def target_dictionary(self): """Return the target :class:`~fairseq.data.Dictionary`.""" diff --git a/speech_recognition.py b/speech_recognition.py old mode 100644 new mode 100755 index a87c42bdc..d9c23f8cb --- a/speech_recognition.py +++ b/speech_recognition.py @@ -1,5 +1,6 @@ -#!/usr/bin/env python3 -u -# Copyright (c) 2018-present, Yiming Wang +#!/usr/bin/env python3 +# Copyright (c) 2017-present, Facebook, Inc. +# 2018-present, Yiming Wang # All rights reserved. # # This source code is licensed under the license found in the LICENSE file in diff --git a/speech_tools/.gitignore b/speech_tools/.gitignore new file mode 100644 index 000000000..6fc07e5f7 --- /dev/null +++ b/speech_tools/.gitignore @@ -0,0 +1 @@ +kaldi diff --git a/speech_tools/dump.sh b/speech_tools/dump.sh new file mode 100755 index 000000000..8202e1bf5 --- /dev/null +++ b/speech_tools/dump.sh @@ -0,0 +1,84 @@ +#!/bin/bash + +# Copyright 2017 Nagoya University (Tomoki Hayashi) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +. ./path.sh + +cmd=run.pl +do_delta=false +nj=1 +verbose=0 +compress=true +write_utt2num_frames=true + +. utils/parse_options.sh + +scp=$1 +cvmnark=$2 +logdir=$3 +dumpdir=$4 + +if [ $# != 4 ]; then + echo "Usage: $0 " + exit 1; +fi + +mkdir -p $logdir +mkdir -p $dumpdir + +dumpdir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' ${dumpdir} ${PWD}` + +for n in $(seq $nj); do + # the next command does nothing unless $dumpdir/storage/ exists, see + # utils/create_data_link.pl for more info. + utils/create_data_link.pl ${dumpdir}/feats.${n}.ark +done + +if $write_utt2num_frames; then + write_num_frames_opt="--write-num-frames=ark,t:$dumpdir/utt2num_frames.JOB" +else + write_num_frames_opt= +fi + +# split scp file +split_scps="" +for n in $(seq $nj); do + split_scps="$split_scps $logdir/feats.$n.scp" +done + +utils/split_scp.pl $scp $split_scps || exit 1; + +# dump features +if ${do_delta};then + $cmd JOB=1:$nj $logdir/dump_feature.JOB.log \ + apply-cmvn --norm-vars=true $cvmnark scp:$logdir/feats.JOB.scp ark:- \| \ + add-deltas ark:- ark:- \| \ + copy-feats --compress=$compress --compression-method=2 ${write_num_frames_opt} \ + ark:- ark,scp:${dumpdir}/feats.JOB.ark,${dumpdir}/feats.JOB.scp \ + || exit 1 +else + $cmd JOB=1:$nj $logdir/dump_feature.JOB.log \ + apply-cmvn --norm-vars=true $cvmnark scp:$logdir/feats.JOB.scp ark:- \| \ + copy-feats --compress=$compress --compression-method=2 ${write_num_frames_opt} \ + ark:- ark,scp:${dumpdir}/feats.JOB.ark,${dumpdir}/feats.JOB.scp \ + || exit 1 +fi + +# concatenate scp files +for n in $(seq $nj); do + cat $dumpdir/feats.$n.scp || exit 1; +done > $dumpdir/feats.scp || exit 1 + +if $write_utt2num_frames; then + for n in $(seq $nj); do + cat $dumpdir/utt2num_frames.$n || exit 1; + done > $dumpdir/utt2num_frames || exit 1 + rm $dumpdir/utt2num_frames.* 2>/dev/null +fi + +# remove temp scps +rm $logdir/feats.*.scp 2>/dev/null +if [ ${verbose} -eq 1 ]; then + echo "Succeeded dumping features for training" +fi diff --git a/speech_tools/parse_options.sh b/speech_tools/parse_options.sh deleted file mode 100755 index 34476fdb3..000000000 --- a/speech_tools/parse_options.sh +++ /dev/null @@ -1,97 +0,0 @@ -#!/bin/bash - -# Copyright 2012 Johns Hopkins University (Author: Daniel Povey); -# Arnab Ghoshal, Karel Vesely - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file 
except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -# MERCHANTABLITY OR NON-INFRINGEMENT. -# See the Apache 2 License for the specific language governing permissions and -# limitations under the License. - - -# Parse command-line options. -# To be sourced by another script (as in ". parse_options.sh"). -# Option format is: --option-name arg -# and shell variable "option_name" gets set to value "arg." -# The exception is --help, which takes no arguments, but prints the -# $help_message variable (if defined). - - -### -### The --config file options have lower priority to command line -### options, so we need to import them first... -### - -# Now import all the configs specified by command-line, in left-to-right order -for ((argpos=1; argpos<$#; argpos++)); do - if [ "${!argpos}" == "--config" ]; then - argpos_plus1=$((argpos+1)) - config=${!argpos_plus1} - [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 - . $config # source the config file. - fi -done - - -### -### No we process the command line options -### -while true; do - [ -z "${1:-}" ] && break; # break if there are no arguments - case "$1" in - # If the enclosing script is called with --help option, print the help - # message and exit. Scripts should put help messages in $help_message - --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2; - else printf "$help_message\n" 1>&2 ; fi; - exit 0 ;; - --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" - exit 1 ;; - # If the first command-line argument begins with "--" (e.g. --foo-bar), - # then work out the variable name as $name, which will equal "foo_bar". - --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; - # Next we test whether the variable in question is undefned-- if so it's - # an invalid option and we die. Note: $0 evaluates to the name of the - # enclosing script. - # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar - # is undefined. We then have to wrap this test inside "eval" because - # foo_bar is itself inside a variable ($name). - eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; - - oldval="`eval echo \\$$name`"; - # Work out whether we seem to be expecting a Boolean argument. - if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then - was_bool=true; - else - was_bool=false; - fi - - # Set the variable to the right value-- the escaped quotes make it work if - # the option had spaces, like --cmd "queue.pl -sync y" - eval $name=\"$2\"; - - # Check that Boolean-valued arguments are really Boolean. - if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then - echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 - exit 1; - fi - shift 2; - ;; - *) break; - esac -done - - -# Check for an empty argument to the --cmd option, which can easily occur as a -# result of scripting errors. -[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; - - -true; # so this script returns exit code 0. 
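Note on the feature-dumping step added earlier in this patch: dump.sh takes four positional arguments (the input feats.scp, the CMVN statistics file, a log directory, and the output dump directory), with the usual Kaldi-style flags handled by utils/parse_options.sh. A minimal invocation could look like the sketch below; the paths, job count, and the assumption that it is run from a prepared recipe directory (e.g. examples/asr_wsj, where the steps/utils symlinks live) are illustrative and not taken from this patch.

# Hypothetical usage sketch; all paths are placeholders.
../../speech_tools/dump.sh --cmd run.pl --nj 8 --do_delta false \
  data/$train_set/feats.scp data/$train_set/cmvn.ark \
  exp/dump_feats/$train_set $train_feat_dir
# The resulting $train_feat_dir/feats.scp (plus utt2num_frames when
# write_utt2num_frames=true) is what the training stage later consumes
# via --train-feat-files.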
diff --git a/speech_train.py b/speech_train.py index a1a312e68..12c10f1f1 100755 --- a/speech_train.py +++ b/speech_train.py @@ -35,6 +35,8 @@ def main(args, init_distributed=False): if torch.cuda.is_available() and not args.cpu: torch.cuda.set_device(args.device_id) torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = True + #torch.backends.cudnn.enabled = False # Setup task, e.g., translation, language modeling, etc. task = tasks.setup_task(args) @@ -102,7 +104,9 @@ def main(args, init_distributed=False): train_meter.start() valid_losses, valid_wers = [None], [None] valid_subsets = args.valid_subset.split(',') - while lr > args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates() < max_update: + while lr > args.min_lr and (epoch_itr.epoch < max_epoch or \ + (epoch_itr.epoch == max_epoch and epoch_itr._next_epoch_itr is not None)) and \ + trainer.get_num_updates() < max_update: # train for one epoch train(args, trainer, task, epoch_itr) @@ -342,7 +346,7 @@ def save_checkpoint(args, trainer, epoch_itr, val_wer): if args.keep_last_epochs > 0: # remove old epoch checkpoints; checkpoints are sorted in descending order - checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint\d+\.pt') + checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint(\d+)\.pt') for old_chk in checkpoints[args.keep_last_epochs:]: if os.path.lexists(old_chk): os.remove(old_chk) From d276c824d676c6e85fa4392a35fc353bff6b4b44 Mon Sep 17 00:00:00 2001 From: freewym Date: Thu, 21 Feb 2019 23:06:08 -0500 Subject: [PATCH 012/119] code adaptation/changes according to the commits from Feb 12, 2019 to Feb 22, 2019 --- fairseq/criterions/cross_entropy_with_wer.py | 18 +-- .../label_smoothed_cross_entropy_with_wer.py | 18 +-- fairseq/speech_recognizer.py | 63 --------- fairseq/tasks/speech_recognition.py | 10 ++ speech_recognition.py | 126 ++++++++---------- 5 files changed, 86 insertions(+), 149 deletions(-) delete mode 100644 fairseq/speech_recognizer.py diff --git a/fairseq/criterions/cross_entropy_with_wer.py b/fairseq/criterions/cross_entropy_with_wer.py index ee595a5a4..0eacc56cd 100644 --- a/fairseq/criterions/cross_entropy_with_wer.py +++ b/fairseq/criterions/cross_entropy_with_wer.py @@ -168,18 +168,18 @@ def aggregate_logging_outputs(logging_outputs): return agg_output def _decode(self, tokens, model, encoder_out, incremental_states): - with torch.no_grad(): - decoder_out = list(model.decoder(tokens, encoder_out, - incremental_state=incremental_states)) - decoder_out[0] = decoder_out[0][:, -1, :] - attn = decoder_out[1] + decoder_out = list(model.decoder(tokens, encoder_out, + incremental_state=incremental_states)) + decoder_out[0] = decoder_out[0][:, -1:, :] + attn = decoder_out[1] + if type(attn) is dict: + attn = attn['attn'] + if attn is not None: if type(attn) is dict: attn = attn['attn'] - if attn is not None: - if type(attn) is dict: - attn = attn['attn'] - attn = attn[:, -1, :] + attn = attn[:, -1, :] probs = model.get_normalized_probs(decoder_out, log_probs=True) + probs = probs[:, -1, :] return probs, attn def set_valid_tgt_dataset(self, dataset): diff --git a/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py b/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py index ce53d42fc..b2efb2b05 100644 --- a/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py +++ b/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py @@ -176,18 +176,18 @@ def aggregate_logging_outputs(logging_outputs): return agg_output def _decode(self, 
tokens, model, encoder_out, incremental_states): - with torch.no_grad(): - decoder_out = list(model.decoder(tokens, encoder_out, - incremental_state=incremental_states)) - decoder_out[0] = decoder_out[0][:, -1, :] - attn = decoder_out[1] + decoder_out = list(model.decoder(tokens, encoder_out, + incremental_state=incremental_states)) + decoder_out[0] = decoder_out[0][:, -1:, :] + attn = decoder_out[1] + if type(attn) is dict: + attn = attn['attn'] + if attn is not None: if type(attn) is dict: attn = attn['attn'] - if attn is not None: - if type(attn) is dict: - attn = attn['attn'] - attn = attn[:, -1, :] + attn = attn[:, -1, :] probs = model.get_normalized_probs(decoder_out, log_probs=True) + probs = probs[:, -1, :] return probs, attn def set_valid_tgt_dataset(self, dataset): diff --git a/fairseq/speech_recognizer.py b/fairseq/speech_recognizer.py deleted file mode 100644 index 3e283d942..000000000 --- a/fairseq/speech_recognizer.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright (c) 2018-present, Yiming Wang -# All rights reserved. -# -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. - -import math - -import torch - -from fairseq import utils - -from fairseq.sequence_generator import SequenceGenerator - - -class SpeechRecognizer(SequenceGenerator): - def generate_batched_itr( - self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None, - cuda=False, timer=None, prefix_size=0, - ): - """Iterate over a batched dataset and yield individual transcription. - - Args: - maxlen_a/b (int, optional): generate sequences of maximum length - ``ax + b``, where ``x`` is the source sentence length. - cuda (bool, optional): use GPU for generation - timer (StopwatchMeter, optional): time generations - prefix_size (int, optional): prefill the generation with the gold - prefix up to this length. 
- """ - if maxlen_b is None: - maxlen_b = self.maxlen - - for sample in data_itr: - s = utils.move_to_cuda(sample) if cuda else sample - if 'net_input' not in s: - continue - input = s['net_input'] - # model.forward normally channels prev_output_tokens into the decoder - # separately, but SequenceGenerator directly calls model.encoder - encoder_input = { - k: v for k, v in input.items() - if k != 'prev_output_tokens' - } - srclen = encoder_input['src_tokens'].size(1) - if timer is not None: - timer.start() - with torch.no_grad(): - hypos = self.generate( - encoder_input, - beam_size=beam_size, - maxlen=int(maxlen_a*srclen + maxlen_b), - prefix_tokens=s['target'][:, :prefix_size] if prefix_size > 0 else None, - ) - if timer is not None: - timer.stop(sum(len(h[0]['tokens']) for h in hypos)) - for i, id in enumerate(s['id'].data): - utt_id = s['utt_id'][i] - # remove padding - ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None - yield id, utt_id, ref, hypos[i] - diff --git a/fairseq/tasks/speech_recognition.py b/fairseq/tasks/speech_recognition.py index cd4c34e83..757131dff 100644 --- a/fairseq/tasks/speech_recognition.py +++ b/fairseq/tasks/speech_recognition.py @@ -200,6 +200,16 @@ def load_dataset(self, split, combine=False, **kwargs): max_target_positions=self.args.max_target_positions, ) + def build_generator(self, args): + if args.score_reference: + args.score_reference = False + print('| --score-reference is not applicable to speech recognition,' + ' ignoring it.') + return super().build_generator(args) + + def build_dataset_for_inference(self, src_tokens, src_lengths): + return SpeechDataset(src_tokens, src_lengths) + def max_positions(self): """Return the max sentence length allowed by the task.""" return (self.args.max_source_positions, self.args.max_target_positions) diff --git a/speech_recognition.py b/speech_recognition.py index d9c23f8cb..40d9ff2a6 100755 --- a/speech_recognition.py +++ b/speech_recognition.py @@ -16,7 +16,6 @@ from fairseq import wer, options, progress_bar, tasks, utils from fairseq.meters import StopwatchMeter, TimeMeter -from fairseq.speech_recognizer import SpeechRecognizer from fairseq.utils import import_user_module from speech_tools.utils import Tokenizer, plot_attention @@ -57,6 +56,8 @@ def main(args): ) if args.fp16: model.half() + if use_cuda: + model.cuda() # Load dataset (possibly sharded) itr = task.get_batch_iterator( @@ -79,77 +80,65 @@ def main(args): print('| The option match_source_len is not applicable to ' 'speech recognition. 
Ignoring it.') gen_timer = StopwatchMeter() - recognizer = SpeechRecognizer( - models, dict, beam_size=args.beam, minlen=args.min_len, - stop_early=(not args.no_early_stop), - normalize_scores=(not args.unnormalized), - len_penalty=args.lenpen, unk_penalty=args.unkpen, - sampling=args.sampling, sampling_topk=args.sampling_topk, - sampling_temperature=args.sampling_temperature, - diverse_beam_groups=args.diverse_beam_groups, - diverse_beam_strength=args.diverse_beam_strength, - match_source_len=False, no_repeat_ngram_size=args.no_repeat_ngram_size, - ) - - if use_cuda: - recognizer.cuda() + generator = task.build_generator(args) # Generate and compute WER scorer = wer.Scorer(dict, wer_output_filter=args.wer_output_filter) num_sentences = 0 has_target = True with progress_bar.build_progress_bar(args, itr) as t: - recognitions = recognizer.generate_batched_itr( - t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b, - cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size, - ) - - sps_meter = TimeMeter() - for sample_id, utt_id, target_tokens, hypos in recognitions: - # Process input and ground truth - has_target = target_tokens is not None - target_tokens = target_tokens.int().cpu() if has_target else None - - # Retrieve the original sentences - if has_target: - target_str = task.dataset(args.gen_subset).tgt.get_original_tokens(sample_id) - if not args.quiet: - target_sent = Tokenizer.tokens_to_sentence(target_str, dict, - use_unk_sym=False) - print('T-{}\t{}'.format(utt_id, target_sent)) - - # Process top predictions - for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]): - hypo_str = dict.string(hypo['tokens'].int().cpu(), args.remove_bpe) - if not args.quiet or i == 0: - hypo_sent = Tokenizer.tokens_to_sentence(hypo_str, dict) - - if not args.quiet: - print('H-{}\t{}\t{}'.format(utt_id, hypo_sent, hypo['score'])) - ''' - print('P-{}\t{}'.format( - utt_id, - ' '.join(map( - lambda x: '{:.4f}'.format(x), - hypo['positional_scores'].tolist(), - )) - )) - ''' - - # Score and obtain attention only the top hypothesis - if i == 0: - # src_len x tgt_len - attention = hypo['attention'].float().cpu() \ - if hypo['attention'] is not None else None - if attention is not None: - save_dir = os.path.join(args.output_dir, 'attn_plots') - os.makedirs(save_dir, exist_ok=True) - plot_attention(attention, hypo_sent, utt_id, save_dir) - scorer.add_prediction(utt_id, hypo_str) - if has_target: - scorer.add_evaluation(utt_id, target_str, hypo_str) - - num_sentences += 1 + wps_meter = TimeMeter() + for sample in t: + sample = utils.move_to_cuda(sample) if use_cuda else sample + if 'net_input' not in sample: + continue + + prefix_tokens = None + if args.prefix_size > 0: + prefix_tokens = sample['target'][:, :args.prefix_size] + + gen_timer.start() + hypos = task.inference_step(generator, models, sample, prefix_tokens) + num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos) + gen_timer.stop(num_generated_tokens) + + for i, sample_id in enumerate(sample['id'].tolist()): + has_target = sample['target'] is not None + utt_id = sample['utt_id'][i] + + # Retrieve the original sentences + if has_target: + target_str = task.dataset(args.gen_subset).tgt.get_original_tokens(sample_id) + if not args.quiet: + target_sent = Tokenizer.tokens_to_sentence(target_str, + dict, use_unk_sym=False) + print('T-{}\t{}'.format(utt_id, target_sent)) + + # Process top predictions + for i, hypo in enumerate(hypos[i][:min(len(hypos), args.nbest)]): + hypo_str = dict.string(hypo['tokens'].int().cpu(), args.remove_bpe) + 
if not args.quiet or i == 0: + hypo_sent = Tokenizer.tokens_to_sentence(hypo_str, dict) + + if not args.quiet: + print('H-{}\t{}\t{}'.format(utt_id, hypo_sent, hypo['score'])) + + # Score and obtain attention only the top hypothesis + if i == 0: + # src_len x tgt_len + attention = hypo['attention'].float().cpu() \ + if hypo['attention'] is not None else None + if attention is not None: + save_dir = os.path.join(args.output_dir, 'attn_plots') + os.makedirs(save_dir, exist_ok=True) + plot_attention(attention, hypo_sent, utt_id, save_dir) + scorer.add_prediction(utt_id, hypo_str) + if has_target: + scorer.add_evaluation(utt_id, target_str, hypo_str) + + wps_meter.update(num_generated_tokens) + t.log({'wps': round(wps_meter.avg)}) + num_sentences += sample['nsentences'] print('| Recognized {} utterances in {:.1f}s ({:.2f} utterances/s)'.format( num_sentences, gen_timer.sum, 1. / gen_timer.avg)) @@ -166,11 +155,12 @@ def main(args): print('| Decoded results saved as ' + f.name) if has_target: + header = ' Recognize {} with beam={}: '.format(args.gen_subset, args.beam) fn = 'wer' with open(os.path.join(args.output_dir, fn), 'w', encoding='utf-8') as f: res = 'WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%'.format( *(scorer.wer())) - print('| Recognize {} with beam={}: '.format(args.gen_subset, args.beam) + res) + print('|' + header + res) f.write(res + '\n') print('| WER saved in ' + f.name) @@ -178,7 +168,7 @@ def main(args): with open(os.path.join(args.output_dir, fn), 'w', encoding='utf-8') as f: res = 'CER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%'.format( *(scorer.cer())) - print('| ' + res) + print('|' + ' ' * len(header) + res) f.write(res + '\n') print('| CER saved in ' + f.name) From c1a54da4d95864b23c5d712cd0f5b7d20f3b4134 Mon Sep 17 00:00:00 2001 From: freewym Date: Sun, 3 Mar 2019 17:36:45 -0500 Subject: [PATCH 013/119] code adaptation/changes according to the commits from Feb 23, 2019 to Mar 4, 2019 --- examples/asr_wsj/conf/fbank.conf | 2 +- examples/asr_wsj/run.sh | 2 +- fairseq/data/scp_dataset.py | 6 +-- fairseq/data/token_dictionary.py | 19 +++----- fairseq/models/speech_lstm.py | 4 +- fairseq/modules/speech_attention.py | 2 +- fairseq/wer.py | 2 +- speech_tools/utils.py | 2 +- speech_train.py | 67 +++++++++++++++-------------- tests/test_speech_dataset.py | 4 +- tests/test_speech_utils.py | 2 +- 11 files changed, 54 insertions(+), 58 deletions(-) diff --git a/examples/asr_wsj/conf/fbank.conf b/examples/asr_wsj/conf/fbank.conf index 82ac7bd0d..752323586 100644 --- a/examples/asr_wsj/conf/fbank.conf +++ b/examples/asr_wsj/conf/fbank.conf @@ -1,2 +1,2 @@ ---sample-frequency=16000 +--sample-frequency=16000 --num-mel-bins=80 diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index 37cdade9b..a09847529 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -81,7 +81,7 @@ train_text=data/$train_set/text if [ ${stage} -le 2 ]; then echo "Stage 2: Dictionary Preparation and Text Tokenization" mkdir -p data/lang - + echo "$0: making a non-linguistic symbol list..." 
cut -f 2- $train_text | tr " " "\n" | sort | uniq | grep "<" > $nlsyms cat $nlsyms diff --git a/fairseq/data/scp_dataset.py b/fairseq/data/scp_dataset.py index d4aadddcd..ce4f4ee25 100644 --- a/fairseq/data/scp_dataset.py +++ b/fairseq/data/scp_dataset.py @@ -145,7 +145,7 @@ class ScpInMemoryDataset(ScpDataset): def __init__(self, path): super().__init__(path) self.read_data() - + def read_data(self): self.data_offsets = np.append([0], np.cumsum(self.sizes)[:-1]) self.buffer = np.empty((sum(self.sizes), self.feat_dim), @@ -158,7 +158,7 @@ def read_data(self): def filter_and_reorder(self, indices): super().filter_and_reorder(indices) self.read_data() - + def __getitem__(self, i): self.check_index(i) ptx = self.data_offsets[i] @@ -221,7 +221,7 @@ def __getitem__(self, i): def get_original_tokens(self, i): self.check_index(i) return self.tokens_list[i] - + def get_original_text(self, i, dictionary): self.check_index(i) return Tokenizer.tokens_to_sentence(self.tokens_list[i], dictionary, diff --git a/fairseq/data/token_dictionary.py b/fairseq/data/token_dictionary.py index faaabb3a2..de3a3b6bf 100644 --- a/fairseq/data/token_dictionary.py +++ b/fairseq/data/token_dictionary.py @@ -5,10 +5,10 @@ # the root directory of this source tree. An additional grant of patent rights # can be found in the PATENTS file in the same directory. -from fairseq.data import Dictionary - import torch +from fairseq.data import Dictionary, data_utils + class TokenDictionary(Dictionary): """A mapping from symbols to consecutive integers""" @@ -40,17 +40,10 @@ def token_string(i): else: return self[i] - if bpe_symbol == 'sentencepiece': - sent = ''.join(token_string(i) for i in tensor if i != self.eos() \ - and i != self.pad()) - sent = sent.replace('\u2581', ' ').strip() - else: - sent = ' '.join(token_string(i) for i in tensor if i != self.eos() \ - and i != self.pad()) - if bpe_symbol is not None and bpe_symbol != 'sentencepiece': - sent = (sent + ' ').replace(bpe_symbol, '').rstrip() - return sent - + sent = ' '.join(token_string(i) for i in tensor if i != self.eos() and \ + i != self.pad()) + return data_utils.process_bpe_symbol(sent, bpe_symbol) + def space(self): """Helper to get index of space symbol""" return self.space_index diff --git a/fairseq/models/speech_lstm.py b/fairseq/models/speech_lstm.py index 03edc59c6..667f70fdc 100644 --- a/fairseq/models/speech_lstm.py +++ b/fairseq/models/speech_lstm.py @@ -252,7 +252,7 @@ def __init__( self.dropout_out = dropout_out self.bidirectional = bidirectional self.hidden_size = hidden_size - + self.lstm = LSTM( input_size=input_size, hidden_size=hidden_size, @@ -440,7 +440,7 @@ def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None): # hidden state concatenated with context vector becomes the # input to the next layer - input = torch.cat((hidden, context), dim=1) + input = torch.cat((hidden, context), dim=1) input = F.dropout(input, p=self.dropout_out, training=self.training) # save state for next time step diff --git a/fairseq/modules/speech_attention.py b/fairseq/modules/speech_attention.py index f0e083a4c..f04a45d18 100644 --- a/fairseq/modules/speech_attention.py +++ b/fairseq/modules/speech_attention.py @@ -25,7 +25,7 @@ def __init__(self, query_dim, value_dim, embed_dim=None): def reset_parameters(self): pass - + def forward(self, query, value, key_padding_mask=None, state=None): # query: bsz x q_hidden # value: len x bsz x v_hidden diff --git a/fairseq/wer.py b/fairseq/wer.py index bc9782ea7..9eab5665e 100644 --- a/fairseq/wer.py +++ 
b/fairseq/wer.py @@ -148,7 +148,7 @@ def print_results(self): for utt_id in self.results: res += utt_id + ' ' + self.results[utt_id] return res - + def print_aligned_results(self): res = '' if self.ordered_utt_list is not None: diff --git a/speech_tools/utils.py b/speech_tools/utils.py index de847008d..7a8fdb944 100644 --- a/speech_tools/utils.py +++ b/speech_tools/utils.py @@ -40,7 +40,7 @@ def tokenize(sent, space='', non_lang_syms=None): tokens = [space if token == ' ' else token for token in tokens] return ' '.join(tokens) - + @staticmethod def tokens_to_index_tensor(line, dict, append_eos=True): tokens = line.strip().split() diff --git a/speech_train.py b/speech_train.py index 12c10f1f1..a88c6b1c1 100755 --- a/speech_train.py +++ b/speech_train.py @@ -104,7 +104,7 @@ def main(args, init_distributed=False): train_meter.start() valid_losses, valid_wers = [None], [None] valid_subsets = args.valid_subset.split(',') - while lr > args.min_lr and (epoch_itr.epoch < max_epoch or \ + while lr >= args.min_lr and (epoch_itr.epoch < max_epoch or \ (epoch_itr.epoch == max_epoch and epoch_itr._next_epoch_itr is not None)) and \ trainer.get_num_updates() < max_update: # train for one epoch @@ -127,13 +127,14 @@ def train(args, trainer, task, epoch_itr): """Train the model for one epoch.""" # Update parameters every N batches - if epoch_itr.epoch <= len(args.update_freq): - update_freq = args.update_freq[epoch_itr.epoch - 1] - else: - update_freq = args.update_freq[-1] + update_freq = args.update_freq[epoch_itr.epoch - 1] \ + if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1] # Initialize data iterator - itr = epoch_itr.next_epoch_itr(fix_batches_to_gpus=args.fix_batches_to_gpus) + itr = epoch_itr.next_epoch_itr( + fix_batches_to_gpus=args.fix_batches_to_gpus, + shuffle=(epoch_itr.epoch >= args.curriculum), + ) itr = iterators.GroupedIterator(itr, update_freq) progress = progress_bar.build_progress_bar( args, itr, epoch_itr.epoch, no_progress_bar='simple', @@ -160,7 +161,7 @@ def train(args, trainer, task, epoch_itr): else: extra_meters[k].update(v) stats[k] = extra_meters[k].avg - progress.log(stats) + progress.log(stats, tag='train', step=stats['num_updates']) # ignore the first mini-batch in words-per-second calculation if i == 0: @@ -178,7 +179,7 @@ def train(args, trainer, task, epoch_itr): stats = get_training_stats(trainer) for k, meter in extra_meters.items(): stats[k] = meter.avg - progress.print(stats) + progress.print(stats, tag='train', step=stats['num_updates']) # reset training meters for k in [ @@ -191,26 +192,26 @@ def train(args, trainer, task, epoch_itr): def get_training_stats(trainer): stats = collections.OrderedDict() - stats['loss'] = '{:.3f}'.format(trainer.get_meter('train_loss').avg) + stats['loss'] = trainer.get_meter('train_loss') if trainer.get_meter('train_nll_loss').count > 0: - nll_loss = trainer.get_meter('train_nll_loss').avg - stats['nll_loss'] = '{:.3f}'.format(nll_loss) + nll_loss = trainer.get_meter('train_nll_loss') + stats['nll_loss'] = nll_loss else: - nll_loss = trainer.get_meter('train_loss').avg - stats['ppl'] = get_perplexity(nll_loss) - stats['wps'] = round(trainer.get_meter('wps').avg) - stats['ups'] = '{:.1f}'.format(trainer.get_meter('ups').avg) - stats['wpb'] = round(trainer.get_meter('wpb').avg) - stats['bsz'] = round(trainer.get_meter('bsz').avg) + nll_loss = trainer.get_meter('train_loss') + stats['ppl'] = get_perplexity(nll_loss.avg) + stats['wps'] = trainer.get_meter('wps') + stats['ups'] = trainer.get_meter('ups') + 
stats['wpb'] = trainer.get_meter('wpb') + stats['bsz'] = trainer.get_meter('bsz') stats['num_updates'] = trainer.get_num_updates() stats['lr'] = trainer.get_lr() - stats['gnorm'] = '{:.3f}'.format(trainer.get_meter('gnorm').avg) - stats['clip'] = '{:.0%}'.format(trainer.get_meter('clip').avg) - stats['oom'] = trainer.get_meter('oom').avg + stats['gnorm'] = trainer.get_meter('gnorm') + stats['clip'] = trainer.get_meter('clip') + stats['oom'] = trainer.get_meter('oom') if trainer.get_meter('loss_scale') is not None: - stats['loss_scale'] = '{:.3f}'.format(trainer.get_meter('loss_scale').avg) + stats['loss_scale'] = trainer.get_meter('loss_scale') stats['wall'] = round(trainer.get_meter('wall').elapsed_time) - stats['train_wall'] = round(trainer.get_meter('train_wall').sum) + stats['train_wall'] = trainer.get_meter('train_wall') return stats @@ -259,11 +260,11 @@ def validate(args, trainer, task, epoch_itr, subsets): 'sample_size', 'word_count', 'char_count']: continue if k == 'word_error': - extra_meters['valid_wer'].update( + extra_meters['wer'].update( float(v) / log_output['word_count'] * 100, log_output['word_count']) elif k == 'char_error': - extra_meters['valid_cer'].update( + extra_meters['cer'].update( float(v) / log_output['char_count'] * 100, log_output['char_count']) else: @@ -273,22 +274,24 @@ def validate(args, trainer, task, epoch_itr, subsets): stats = get_valid_stats(trainer) for k, meter in extra_meters.items(): stats[k] = meter.avg - progress.print(stats) + if hasattr(save_checkpoint, 'best'): + stats['best_wer'] = min(save_checkpoint.best, stats['wer']) + progress.print(stats, tag=subset, step=trainer.get_num_updates()) - valid_losses.append(stats['valid_loss']) - valid_wers.append(stats['valid_wer']) + valid_losses.append(stats['loss'].avg) + valid_wers.append(stats['wer']) return valid_losses, valid_wers def get_valid_stats(trainer): stats = collections.OrderedDict() - stats['valid_loss'] = trainer.get_meter('valid_loss').avg + stats['loss'] = trainer.get_meter('valid_loss') if trainer.get_meter('valid_nll_loss').count > 0: - nll_loss = trainer.get_meter('valid_nll_loss').avg - stats['valid_nll_loss'] = nll_loss + nll_loss = trainer.get_meter('valid_nll_loss') + stats['nll_loss'] = nll_loss else: - nll_loss = trainer.get_meter('valid_loss').avg - stats['valid_ppl'] = get_perplexity(nll_loss) + nll_loss = stats['loss'] + stats['ppl'] = get_perplexity(nll_loss.avg) stats['num_updates'] = trainer.get_num_updates() return stats diff --git a/tests/test_speech_dataset.py b/tests/test_speech_dataset.py index 86decab02..7e1f0bc0e 100644 --- a/tests/test_speech_dataset.py +++ b/tests/test_speech_dataset.py @@ -64,7 +64,7 @@ def generate_text_tokens(test_dir, num=10, seed=0): encoding='utf-8') as f: for i in np.random.permutation(range(num)): utt_id = 'utt_id_' + str(i) - length = np.random.randint(10, 100) + length = np.random.randint(10, 100) tokens = [vocab[np.random.randint(0, len(vocab))] \ for _ in range(length)] if tokens[0] == space: @@ -160,7 +160,7 @@ def test_speech_dataset_cached_no_ordered_prefetch(self): def test_speech_dataset_cached_with_ordered_prefetch(self): self._speech_dataset_helper(all_in_memory=False, ordered_prefetch=True) - + def test_speech_dataset_all_in_memory(self): self._speech_dataset_helper(all_in_memory=True) diff --git a/tests/test_speech_utils.py b/tests/test_speech_utils.py index 9c3a6774b..cf66c6652 100644 --- a/tests/test_speech_utils.py +++ b/tests/test_speech_utils.py @@ -40,7 +40,7 @@ def generate_text(vocab, oovs=[], non_lang_syms=[], 
seed=0): isinstance(non_lang_syms, list) np.random.seed(seed) sent_len = np.random.randint(2, 30) - sent = '' + sent = '' for _ in range(sent_len): if len(non_lang_syms) > 0 and np.random.randint(0, 20) == 0: word = non_lang_syms[np.random.randint(0, len(non_lang_syms))] From 47dc107c1314549465d613cc41d8f28057dd7393 Mon Sep 17 00:00:00 2001 From: freewym Date: Fri, 8 Mar 2019 00:56:13 -0500 Subject: [PATCH 014/119] add lr scheduler that allows to set epoch from which lr starts to decay --- examples/asr_wsj/run.sh | 3 +- .../lr_scheduler/reduce_lr_on_plateau_v2.py | 36 +++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) create mode 100644 fairseq/optim/lr_scheduler/reduce_lr_on_plateau_v2.py diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index a09847529..dfb5136ad 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -123,7 +123,7 @@ if [ ${stage} -le 3 ]; then --valid-subset $valid_subset --max-sentences-valid 64 \ --distributed-world-size 1 --distributed-rank 0 --distributed-port -1 \ --max-epoch 20 --optimizer "adam" --lr 0.001 --weight-decay 0.0 \ - --lr-scheduler "reduce_lr_on_plateau" --lr-shrink 0.5 --min-lr "1e-8" \ + --lr-scheduler "reduce_lr_on_plateau_v2" --lr-shrink 0.5 --min-lr "1e-4" --start-reduce-lr-epoch 11 \ --save-dir $dir --restore-file "checkpoint_last.pt" --save-interval-updates 200 \ --keep-interval-updates 5 --keep-last-epochs 5 --validate-interval 1 \ --arch "speech_conv_lstm_wsj" --criterion "label_smoothed_cross_entropy_with_wer" --label-smoothing 0.05 \ @@ -131,7 +131,6 @@ if [ ${stage} -le 3 ]; then --valid-feat-files $valid_feat --valid-text-files $valid_token_text \ --dict $dict --non-lang-syms $nlsyms \ --max-source-positions 9999 --max-target-positions 999 $opts 2>&1 | tee $log_file -exit 0 fi if [ ${stage} -le 4 ]; then diff --git a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau_v2.py b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau_v2.py new file mode 100644 index 000000000..45259f97f --- /dev/null +++ b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau_v2.py @@ -0,0 +1,36 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import torch.optim.lr_scheduler + +from . 
import FairseqLRScheduler, register_lr_scheduler +from .reduce_lr_on_plateau import ReduceLROnPlateau + + +@register_lr_scheduler('reduce_lr_on_plateau_v2') +class ReduceLROnPlateauV2(ReduceLROnPlateau): + """Decay the LR by a factor every time the validation loss plateausi, after start_epoch_to_reduce.""" + + def __init__(self, args, optimizer): + super().__init__(args, optimizer) + self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + self.optimizer.optimizer, patience=0, factor=args.lr_shrink, + min_lr=args.min_lr) + + @staticmethod + def add_args(parser): + """Add arguments to the parser for this LR scheduler.""" + # fmt: off + parser.add_argument('--start-reduce-lr-epoch', default=0, type=int, metavar='N', + help='start to reduce lr from specified epoch') + # fmt: on + + def step(self, epoch, val_loss=None): + if epoch < self.args.start_reduce_lr_epoch: + self.lr_scheduler.last_epoch = epoch + return self.args.lr[0] + return super().step(epoch, val_loss) From eb1d55f182cad80205d164fc7fa0ba8f85328420 Mon Sep 17 00:00:00 2001 From: freewym Date: Fri, 8 Mar 2019 21:34:16 -0500 Subject: [PATCH 015/119] refactor speech_tools.utils.Tokenizer --- fairseq/criterions/cross_entropy_with_wer.py | 8 +- .../label_smoothed_cross_entropy_with_wer.py | 8 +- fairseq/data/scp_dataset.py | 6 +- fairseq/data/token_dictionary.py | 35 +++++++++ fairseq/wer.py | 7 +- speech_recognition.py | 8 +- speech_tools/text2token.py | 4 +- speech_tools/utils.py | 76 ++++++------------- speech_train.py | 6 +- tests/test_speech_utils.py | 16 ++-- 10 files changed, 88 insertions(+), 86 deletions(-) diff --git a/fairseq/criterions/cross_entropy_with_wer.py b/fairseq/criterions/cross_entropy_with_wer.py index 0eacc56cd..52a21236b 100644 --- a/fairseq/criterions/cross_entropy_with_wer.py +++ b/fairseq/criterions/cross_entropy_with_wer.py @@ -14,8 +14,6 @@ from fairseq.data import data_utils from fairseq.models import FairseqIncrementalDecoder -from speech_tools.utils import Tokenizer - from . import FairseqCriterion, register_criterion from .cross_entropy import CrossEntropyCriterion @@ -127,10 +125,10 @@ def forward(self, model, sample, reduce=True): i = np.random.randint(0, len(sample['id'])) id = sample['id'].data[i] length = utils.strip_pad(target.data[i], self.padding_idx).size(0) - #ref_one = Tokenizer.tokens_to_sentence(dict.string(target.data[i]), dict) + #ref_one = dict.tokens_to_sentence(dict.string(target.data[i])) ref_one = self.train_tgt_dataset.get_original_text(id, dict) - pred_one = Tokenizer.tokens_to_sentence( - dict.string(pred.data[i][:length]), dict) + pred_one = dict.tokens_to_sentence( + dict.string(pred.data[i][:length])) print('| sample REF: ' + ref_one) print('| sample PRD: ' + pred_one) # word error stats code ends diff --git a/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py b/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py index b2efb2b05..57caac238 100644 --- a/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py +++ b/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py @@ -13,8 +13,6 @@ from fairseq.data import data_utils from fairseq.models import FairseqIncrementalDecoder -from speech_tools.utils import Tokenizer - from . 
import FairseqCriterion, register_criterion from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion @@ -127,10 +125,10 @@ def forward(self, model, sample, reduce=True): i = np.random.randint(0, len(sample['id'])) id = sample['id'].data[i] length = utils.strip_pad(target.data[i], self.padding_idx).size(0) - #ref_one = Tokenizer.tokens_to_sentence(dict.string(target.data[i]), dict) + #ref_one = dict.tokens_to_sentence(dict.string(target.data[i])) ref_one = self.train_tgt_dataset.get_original_text(id, dict) - pred_one = Tokenizer.tokens_to_sentence( - dict.string(pred.data[i][:length]), dict) + pred_one = dict.tokens_to_sentence( + dict.string(pred.data[i][:length])) print('| sample REF: ' + ref_one) print('| sample PRD: ' + pred_one) # word error stats code ends diff --git a/fairseq/data/scp_dataset.py b/fairseq/data/scp_dataset.py index ce4f4ee25..bce7bf9c0 100644 --- a/fairseq/data/scp_dataset.py +++ b/fairseq/data/scp_dataset.py @@ -11,7 +11,7 @@ import torch import speech_tools.kaldi_io as kaldi_io -from speech_tools.utils import Tokenizer + class ScpDataset(torch.utils.data.Dataset): """Loader for TorchNet IndexedDataset""" @@ -187,7 +187,7 @@ def read_text(self, path, dictionary): utt_id, tokens = line.strip().split(None, 1) self.utt_ids.append(utt_id) self.tokens_list.append(tokens) - tensor = Tokenizer.tokens_to_index_tensor(tokens, dictionary) + tensor = dictionary.encode_line(tokens, append_eos=self.append_eos) self.tensor_list.append(tensor) self.sizes.append(len(self.tensor_list[-1])) @@ -224,7 +224,7 @@ def get_original_tokens(self, i): def get_original_text(self, i, dictionary): self.check_index(i) - return Tokenizer.tokens_to_sentence(self.tokens_list[i], dictionary, + return dictionary.tokens_to_sentence(self.tokens_list[i], use_unk_sym=False) def __len__(self): diff --git a/fairseq/data/token_dictionary.py b/fairseq/data/token_dictionary.py index de3a3b6bf..f593a1225 100644 --- a/fairseq/data/token_dictionary.py +++ b/fairseq/data/token_dictionary.py @@ -7,6 +7,7 @@ import torch +from fairseq.tokenizer import tokenize_line from fairseq.data import Dictionary, data_utils @@ -91,3 +92,37 @@ def dummy_sentence(self, length): t = torch.Tensor(length).uniform_(self.nspecial, len(self)).long() t[-1] = self.eos() return t + + def encode_line(self, line, line_tokenizer=tokenize_line, add_if_not_exist=False, + consumer=None, append_eos=True, reverse_order=False): + tokens = line_tokenizer(line) + if reverse_order: + tokens = list(reversed(tokens)) + ntokens = len(tokens) + ids = torch.LongTensor(ntokens + 1 if append_eos else ntokens) + + for i, token in enumerate(tokens): + if add_if_not_exist: + idx = self.add_symbol(token) + else: + idx = self.index(token) + ids[i] = idx + if consumer is not None: + consumer(word, idx) + if append_eos: + ids[ntokens] = self.eos_index + return ids + + def tokens_to_sentence(self, line, line_tokenizer=tokenize_line, use_unk_sym=True): + # use_unk_sym=False when we want to restore original transcripts from + # token sequences, e.g., obtain reference to compute WER + tokens = line_tokenizer(line) + sent = "" + for token in tokens: + if token == self.space_word: + sent += " " + elif use_unk_sym and self.index(token) == self.unk_index: + sent += self.unk_word + elif token != self.pad_word and token != self.eos_word: + sent += token + return sent.strip() diff --git a/fairseq/wer.py b/fairseq/wer.py index 9eab5665e..03a541e94 100644 --- a/fairseq/wer.py +++ b/fairseq/wer.py @@ -50,7 +50,7 @@ def add_prediction(self, utt_id, pred): if 
not isinstance(pred, str): raise TypeError('pred must be a string(got {})'.format(type(pred))) - pred_words = speech_utils.Tokenizer.tokens_to_sentence(pred, self.dict) + pred_words = self.dict.tokens_to_sentence(pred) assert not utt_id in self.results, \ 'Duplicated utterance id detected: {}'.format(utt_id) self.results[utt_id] = pred_words + '\n' @@ -77,9 +77,8 @@ def add_evaluation(self, utt_id, ref, pred): self.char_counter += counter # word level counts - ref_words = speech_utils.Tokenizer.tokens_to_sentence(ref, self.dict, - use_unk_sym=False) - pred_words = speech_utils.Tokenizer.tokens_to_sentence(pred, self.dict) + ref_words = self.dict.tokens_to_sentence(ref, use_unk_sym=False) + pred_words = self.dict.tokens_to_sentence(pred) # filter words according to self.word_filters (support re.sub only) for pattern, repl in self.word_filters: diff --git a/speech_recognition.py b/speech_recognition.py index 40d9ff2a6..d271242a6 100755 --- a/speech_recognition.py +++ b/speech_recognition.py @@ -17,7 +17,7 @@ from fairseq import wer, options, progress_bar, tasks, utils from fairseq.meters import StopwatchMeter, TimeMeter from fairseq.utils import import_user_module -from speech_tools.utils import Tokenizer, plot_attention +from speech_tools.utils import plot_attention def main(args): @@ -110,15 +110,15 @@ def main(args): if has_target: target_str = task.dataset(args.gen_subset).tgt.get_original_tokens(sample_id) if not args.quiet: - target_sent = Tokenizer.tokens_to_sentence(target_str, - dict, use_unk_sym=False) + target_sent = dict.tokens_to_sentence(target_str, + use_unk_sym=False) print('T-{}\t{}'.format(utt_id, target_sent)) # Process top predictions for i, hypo in enumerate(hypos[i][:min(len(hypos), args.nbest)]): hypo_str = dict.string(hypo['tokens'].int().cpu(), args.remove_bpe) if not args.quiet or i == 0: - hypo_sent = Tokenizer.tokens_to_sentence(hypo_str, dict) + hypo_sent = dict.tokens_to_sentence(hypo_str) if not args.quiet: print('H-{}\t{}\t{}'.format(utt_id, hypo_sent, hypo['score'])) diff --git a/speech_tools/text2token.py b/speech_tools/text2token.py index 1c90e027a..54ddd99e3 100755 --- a/speech_tools/text2token.py +++ b/speech_tools/text2token.py @@ -9,7 +9,7 @@ import argparse import sys -from utils import Tokenizer +from utils import tokenize def get_parser(): @@ -37,7 +37,7 @@ def main(args): with (open(args.text, 'r', encoding='utf-8') if args.text else sys.stdin) as f: for line in f: entry = line.rstrip().split() - tokenized = Tokenizer.tokenize(' '.join(entry[args.skip_ncols:]), + tokenized = tokenize(' '.join(entry[args.skip_ncols:]), space=args.space, non_lang_syms=nls) print(' '.join(entry[:args.skip_ncols]) + ' ' + tokenized) diff --git a/speech_tools/utils.py b/speech_tools/utils.py index 7a8fdb944..2b6843f51 100644 --- a/speech_tools/utils.py +++ b/speech_tools/utils.py @@ -14,59 +14,29 @@ from fairseq.utils import buffered_arange, item -class Tokenizer: - - @staticmethod - def tokenize(sent, space='', non_lang_syms=None): - assert isinstance(sent, str) - sent = ' '.join(sent.strip().split()) - - match_pos = [] - if non_lang_syms is not None: - assert isinstance(non_lang_syms, list) - if len(non_lang_syms) > 0: - prog = re.compile('|'.join(map(re.escape, non_lang_syms))) - matches = prog.finditer(sent) - for match in matches: - match_pos.append([match.start(), match.end()]) - - tokens = [] - i = 0 - for (start_pos, end_pos) in match_pos: - tokens.extend([token for token in sent[i:start_pos]]) - tokens.append(sent[start_pos:end_pos]) - i = end_pos - 
tokens.extend([token for token in sent[i:]]) - - tokens = [space if token == ' ' else token for token in tokens] - return ' '.join(tokens) - - @staticmethod - def tokens_to_index_tensor(line, dict, append_eos=True): - tokens = line.strip().split() - ntokens = len(tokens) - ids = torch.LongTensor(ntokens + 1 if append_eos else ntokens) - - for i, token in enumerate(tokens): - ids[i] = dict.index(token) - if append_eos: - ids[ntokens] = dict.eos_index - return ids - - @staticmethod - def tokens_to_sentence(line, dict, use_unk_sym=True): - # use_unk_sym=False when we want to restore original transcripts from - # token sequences, e.g., obtain reference to compute WER - tokens = line.strip().split() - sent = "" - for token in tokens: - if token == dict.space_word: - sent += " " - elif use_unk_sym and dict.index(token) == dict.unk(): - sent += dict.unk_word - elif token != dict.pad_word and token != dict.eos_word: - sent += token - return sent.strip() +def tokenize(sent, space='', non_lang_syms=None): + assert isinstance(sent, str) + sent = ' '.join(sent.strip().split()) + + match_pos = [] + if non_lang_syms is not None: + assert isinstance(non_lang_syms, list) + if len(non_lang_syms) > 0: + prog = re.compile('|'.join(map(re.escape, non_lang_syms))) + matches = prog.finditer(sent) + for match in matches: + match_pos.append([match.start(), match.end()]) + + tokens = [] + i = 0 + for (start_pos, end_pos) in match_pos: + tokens.extend([token for token in sent[i:start_pos]]) + tokens.append(sent[start_pos:end_pos]) + i = end_pos + tokens.extend([token for token in sent[i:]]) + + tokens = [space if token == ' ' else token for token in tokens] + return ' '.join(tokens) def collate_frames(values, pad_value=0.0, left_pad=False): """Convert a list of 2d tensor into a padded 3d tensor.""" diff --git a/speech_train.py b/speech_train.py index a88c6b1c1..6f0116c96 100755 --- a/speech_train.py +++ b/speech_train.py @@ -36,7 +36,8 @@ def main(args, init_distributed=False): torch.cuda.set_device(args.device_id) torch.manual_seed(args.seed) torch.backends.cudnn.deterministic = True - #torch.backends.cudnn.enabled = False + if args.disable_cudnn: + torch.backends.cudnn.enabled = False # Setup task, e.g., translation, language modeling, etc. 
task = tasks.setup_task(args) @@ -413,6 +414,9 @@ def print_options_meaning_changes(args): def cli_main(): parser = options.get_training_parser(default_task='speech_recognition') + parser.add_argument('--disable-cudnn', action='store_true', + help='disable cudnn, which would make the training ' + 'much slower') args = options.parse_args_and_arch(parser) print_options_meaning_changes(args) diff --git a/tests/test_speech_utils.py b/tests/test_speech_utils.py index cf66c6652..74f075896 100644 --- a/tests/test_speech_utils.py +++ b/tests/test_speech_utils.py @@ -72,13 +72,12 @@ def test_speech_tokenizer(self): for i, sent in enumerate(self.text): print('test sentence {}:'.format(i)) print(sent) - tokens = utils.Tokenizer.tokenize(sent, \ + tokens = utils.tokenize(sent, \ space=self.dict.space_word, non_lang_syms=self.non_lang_syms) - # test :func:`~speech_tools.utils.Tokenizer.tokenize` with - # :func:`~speech_tools.utils.Tokenizer.tokens_to_index_tensor` - tensor = utils.Tokenizer.tokens_to_index_tensor(tokens, self.dict, \ - append_eos=True) + # test :func:`~speech_tools.utils.tokenize` with + # :func:`~TokenDictionary.encode_line` + tensor = self.dict.encode_line(tokens, append_eos=True) reconstructed_tokens = self.dict.string(tensor) expected_tokens = ' '.join( [token if self.dict.index(token) != self.dict.unk() else \ @@ -86,10 +85,9 @@ def test_speech_tokenizer(self): ) self.assertEqual(reconstructed_tokens, expected_tokens) - # test :func:`~speech_tools.utils.Tokenizer.tokenize` with - # :func:`~speech_tools.utils.Tokenizer.tokens_to_sentence` - reconstructed_sent = utils.Tokenizer.tokens_to_sentence(tokens, - self.dict) + # test :func:`~speech_tools.utils.tokenize` with + # :func:`~TokenDictionary.tokens_to_sentence` + reconstructed_sent = self.dict.tokens_to_sentence(tokens) expected_sent = [] words = sent.split(' ') for w in words: From f0930b1912d881439cb8294631934c04a77f0d8d Mon Sep 17 00:00:00 2001 From: freewym Date: Sat, 9 Mar 2019 03:36:47 -0500 Subject: [PATCH 016/119] lm training --- examples/asr_wsj/local/find_transcripts.pl | 65 +- examples/asr_wsj/local/flist2scp.pl | 32 +- examples/asr_wsj/local/ndx2flist.pl | 63 +- .../asr_wsj/local/normalize_transcript.pl | 60 +- examples/asr_wsj/local/wsj_data_prep.sh | 215 +----- examples/asr_wsj/local/wsj_format_data.sh | 36 +- examples/asr_wsj/run.sh | 120 +++- fairseq/data/scp_dataset.py | 3 +- fairseq/data/token_dictionary.py | 20 - fairseq/models/speech_lstm.py | 234 +++++-- fairseq/modules/speech_attention.py | 12 +- fairseq/tasks/language_modeling_for_asr.py | 125 ++++ fairseq/tasks/speech_recognition.py | 9 + speech_recognition.py => speech_recognize.py | 28 +- speech_tools/.gitignore | 2 + speech_tools/Makefile | 8 +- speech_tools/kaldi_io.py | 630 ------------------ speech_tools/text2token.py | 5 +- speech_train.py | 1 - tests/test_speech_utils.py | 3 +- 20 files changed, 469 insertions(+), 1202 deletions(-) mode change 100755 => 120000 examples/asr_wsj/local/find_transcripts.pl mode change 100755 => 120000 examples/asr_wsj/local/flist2scp.pl mode change 100755 => 120000 examples/asr_wsj/local/ndx2flist.pl mode change 100755 => 120000 examples/asr_wsj/local/normalize_transcript.pl mode change 100755 => 120000 examples/asr_wsj/local/wsj_data_prep.sh mode change 100755 => 120000 examples/asr_wsj/local/wsj_format_data.sh create mode 100644 fairseq/tasks/language_modeling_for_asr.py rename speech_recognition.py => speech_recognize.py (85%) delete mode 100644 speech_tools/kaldi_io.py diff --git 
a/examples/asr_wsj/local/find_transcripts.pl b/examples/asr_wsj/local/find_transcripts.pl deleted file mode 100755 index 6429411b8..000000000 --- a/examples/asr_wsj/local/find_transcripts.pl +++ /dev/null @@ -1,64 +0,0 @@ -#!/usr/bin/env perl -# Copyright 2010-2011 Microsoft Corporation - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -# MERCHANTABLITY OR NON-INFRINGEMENT. -# See the Apache 2 License for the specific language governing permissions and -# limitations under the License. - - - -# This program takes on its standard input a list of utterance -# id's, one for each line. (e.g. 4k0c030a is a an utterance id). -# It takes as -# Extracts from the dot files the transcripts for a given -# dataset (represented by a file list). -# - -@ARGV == 1 || die "find_transcripts.pl dot_files_flist < utterance_ids > transcripts"; -$dot_flist = shift @ARGV; - -open(L, "<$dot_flist") || die "Opening file list of dot files: $dot_flist\n"; -while(){ - chop; - m:\S+/(\w{6})00.dot: || die "Bad line in dot file list: $_"; - $spk = $1; - $spk2dot{$spk} = $_; -} - - - -while(){ - chop; - $uttid = $_; - $uttid =~ m:(\w{6})\w\w: || die "Bad utterance id $_"; - $spk = $1; - if($spk ne $curspk) { - %utt2trans = { }; # Don't keep all the transcripts in memory... - $curspk = $spk; - $dotfile = $spk2dot{$spk}; - defined $dotfile || die "No dot file for speaker $spk\n"; - open(F, "<$dotfile") || die "Error opening dot file $dotfile\n"; - while() { - $_ =~ m:(.+)\((\w{8})\)\s*$: || die "Bad line $_ in dot file $dotfile (line $.)\n"; - $trans = $1; - $utt = $2; - $utt2trans{$utt} = $trans; - } - } - if(!defined $utt2trans{$uttid}) { - print STDERR "No transcript for utterance $uttid (current dot file is $dotfile)\n"; - } else { - print "$uttid $utt2trans{$uttid}\n"; - } -} - - diff --git a/examples/asr_wsj/local/find_transcripts.pl b/examples/asr_wsj/local/find_transcripts.pl new file mode 120000 index 000000000..2455a71d6 --- /dev/null +++ b/examples/asr_wsj/local/find_transcripts.pl @@ -0,0 +1 @@ +../../../speech_tools/kaldi/egs/wsj/s5/local/find_transcripts.pl \ No newline at end of file diff --git a/examples/asr_wsj/local/flist2scp.pl b/examples/asr_wsj/local/flist2scp.pl deleted file mode 100755 index 234e4add1..000000000 --- a/examples/asr_wsj/local/flist2scp.pl +++ /dev/null @@ -1,31 +0,0 @@ -#!/usr/bin/env perl -# Copyright 2010-2011 Microsoft Corporation - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -# MERCHANTABLITY OR NON-INFRINGEMENT. -# See the Apache 2 License for the specific language governing permissions and -# limitations under the License. 
- - -# takes in a file list with lines like -# /mnt/matylda2/data/WSJ1/13-16.1/wsj1/si_dt_20/4k0/4k0c030a.wv1 -# and outputs an scp in kaldi format with lines like -# 4k0c030a /mnt/matylda2/data/WSJ1/13-16.1/wsj1/si_dt_20/4k0/4k0c030a.wv1 -# (the first thing is the utterance-id, which is the same as the basename of the file. - - -while(<>){ - m:^\S+/(\w+)\.[wW][vV]1$: || die "Bad line $_"; - $id = $1; - $id =~ tr/A-Z/a-z/; # Necessary because of weirdness on disk 13-16.1 (uppercase filenames) - print "$id $_"; -} - diff --git a/examples/asr_wsj/local/flist2scp.pl b/examples/asr_wsj/local/flist2scp.pl new file mode 120000 index 000000000..e7c6f9da4 --- /dev/null +++ b/examples/asr_wsj/local/flist2scp.pl @@ -0,0 +1 @@ +../../../speech_tools/kaldi/egs/wsj/s5/local/flist2scp.pl \ No newline at end of file diff --git a/examples/asr_wsj/local/ndx2flist.pl b/examples/asr_wsj/local/ndx2flist.pl deleted file mode 100755 index 48fc3dec1..000000000 --- a/examples/asr_wsj/local/ndx2flist.pl +++ /dev/null @@ -1,62 +0,0 @@ -#!/usr/bin/env perl -# Copyright 2010-2011 Microsoft Corporation - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -# MERCHANTABLITY OR NON-INFRINGEMENT. -# See the Apache 2 License for the specific language governing permissions and -# limitations under the License. - - -# This program takes as its standard input an .ndx file from the WSJ corpus that looks -# like this: -#;; File: tr_s_wv1.ndx, updated 04/26/94 -#;; -#;; Index for WSJ0 SI-short Sennheiser training data -#;; Data is read WSJ sentences, Sennheiser mic. -#;; Contains 84 speakers X (~100 utts per speaker MIT/SRI and ~50 utts -#;; per speaker TI) = 7236 utts -#;; -#11_1_1:wsj0/si_tr_s/01i/01ic0201.wv1 -#11_1_1:wsj0/si_tr_s/01i/01ic0202.wv1 -#11_1_1:wsj0/si_tr_s/01i/01ic0203.wv1 - -#and as command-line arguments it takes the names of the WSJ disk locations, e.g.: -#/mnt/matylda2/data/WSJ0/11-1.1 /mnt/matylda2/data/WSJ0/11-10.1 ... etc. -# It outputs a list of absolute pathnames (it does this by replacing e.g. 11_1_1 with -# /mnt/matylda2/data/WSJ0/11-1.1. -# It also does a slight fix because one of the WSJ disks (WSJ1/13-16.1) was distributed with -# uppercase rather than lower case filenames. - -foreach $fn (@ARGV) { - $fn =~ m:.+/([0-9\.\-]+)/?$: || die "Bad command-line argument $fn\n"; - $disk_id=$1; - $disk_id =~ tr/-\./__/; # replace - and . with - so 11-10.1 becomes 11_10_1 - $fn =~ s:/$::; # Remove final slash, just in case it is present. - $disk2fn{$disk_id} = $fn; -} - -while(){ - if(m/^;/){ next; } # Comment. Ignore it. - else { - m/^([0-9_]+):\s*(\S+)$/ || die "Could not parse line $_"; - $disk=$1; - if(!defined $disk2fn{$disk}) { - die "Disk id $disk not found"; - } - $filename = $2; # as a subdirectory of the distributed disk. - if($disk eq "13_16_1" && `hostname` =~ m/fit.vutbr.cz/) { - # The disk 13-16.1 has been uppercased for some reason, on the - # BUT system. This is a fix specifically for that case. - $filename =~ tr/a-z/A-Z/; # This disk contains all uppercase filenames. Why? 
- } - print "$disk2fn{$disk}/$filename\n"; - } -} diff --git a/examples/asr_wsj/local/ndx2flist.pl b/examples/asr_wsj/local/ndx2flist.pl new file mode 120000 index 000000000..2c868304e --- /dev/null +++ b/examples/asr_wsj/local/ndx2flist.pl @@ -0,0 +1 @@ +../../../speech_tools/kaldi/egs/wsj/s5/local/ndx2flist.pl \ No newline at end of file diff --git a/examples/asr_wsj/local/normalize_transcript.pl b/examples/asr_wsj/local/normalize_transcript.pl deleted file mode 100755 index 09cee0617..000000000 --- a/examples/asr_wsj/local/normalize_transcript.pl +++ /dev/null @@ -1,59 +0,0 @@ -#!/usr/bin/env perl -# Copyright 2010-2011 Microsoft Corporation - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -# MERCHANTABLITY OR NON-INFRINGEMENT. -# See the Apache 2 License for the specific language governing permissions and -# limitations under the License. - - -# This takes data from the standard input that's unnormalized transcripts in the format -# 4k2c0308 Of course there isn\'t any guarantee the company will keep its hot hand [misc_noise] -# 4k2c030a [loud_breath] And new hardware such as the set of personal computers I\. B\. M\. introduced last week can lead to unexpected changes in the software business [door_slam] -# and outputs normalized transcripts. -# c.f. /mnt/matylda2/data/WSJ0/11-10.1/wsj0/transcrp/doc/dot_spec.doc - -@ARGV == 1 || die "usage: normalize_transcript.pl noise_word < transcript > transcript2"; -$noise_word = shift @ARGV; - -while() { - $_ =~ m:^(\S+) (.+): || die "bad line $_"; - $utt = $1; - $trans = $2; - print "$utt"; - foreach $w (split (" ",$trans)) { - $w =~ tr:a-z:A-Z:; # Upcase everything to match the CMU dictionary. . - $w =~ s:\\::g; # Remove backslashes. We don't need the quoting. - $w =~ s:^\%PERCENT$:PERCENT:; # Normalization for Nov'93 test transcripts. - $w =~ s:^\.POINT$:POINT:; # Normalization for Nov'93 test transcripts. - if($w =~ m:^\[\<\w+\]$: || # E.g. [\]$: || # E.g. [door_slam>], this means a door slammed in the next word. Delete. - $w =~ m:\[\w+/\]$: || # E.g. [phone_ring/], which indicates the start of this phenomenon. - $w =~ m:\[\/\w+]$: || # E.g. [/phone_ring], which indicates the end of this phenomenon. - $w eq "~" || # This is used to indicate truncation of an utterance. Not a word. - $w eq ".") { # "." is used to indicate a pause. Silence is optional anyway so not much - # point including this in the transcript. - next; # we won't print this word. - } elsif($w =~ m:\[\w+\]:) { # Other noises, e.g. [loud_breath]. - print " $noise_word"; - } elsif($w =~ m:^\<([\w\']+)\>$:) { - # e.g. replace with and. (the <> means verbal deletion of a word).. but it's pronounced. - print " $1"; - } elsif($w eq "--DASH") { - print " -DASH"; # This is a common issue; the CMU dictionary has it as -DASH. -# } elsif($w =~ m:(.+)\-DASH$:) { # E.g. INCORPORATED-DASH... 
seems the DASH gets combined with previous word -# print " $1 -DASH"; - } else { - print " $w"; - } - } - print "\n"; -} diff --git a/examples/asr_wsj/local/normalize_transcript.pl b/examples/asr_wsj/local/normalize_transcript.pl new file mode 120000 index 000000000..975e24acf --- /dev/null +++ b/examples/asr_wsj/local/normalize_transcript.pl @@ -0,0 +1 @@ +../../../speech_tools/kaldi/egs/wsj/s5/local/normalize_transcript.pl \ No newline at end of file diff --git a/examples/asr_wsj/local/wsj_data_prep.sh b/examples/asr_wsj/local/wsj_data_prep.sh deleted file mode 100755 index 04f2f6390..000000000 --- a/examples/asr_wsj/local/wsj_data_prep.sh +++ /dev/null @@ -1,214 +0,0 @@ -#!/bin/bash - -# Copyright 2009-2012 Microsoft Corporation Johns Hopkins University (Author: Daniel Povey) -# Apache 2.0. - - -if [ $# -le 3 ]; then - echo "Arguments should be a list of WSJ directories, see ../run.sh for example." - exit 1; -fi - - -dir=`pwd`/data/local/data -lmdir=`pwd`/data/local/nist_lm -mkdir -p $dir $lmdir -local=`pwd`/local -utils=`pwd`/utils - -. ./path.sh # Needed for KALDI_ROOT -sph2pipe=$KALDI_ROOT/tools/sph2pipe_v2.5/sph2pipe -if [ ! -x $sph2pipe ]; then - echo "Could not find (or execute) the sph2pipe program at $sph2pipe"; - exit 1; -fi - -if [ -z $IRSTLM ] ; then - export IRSTLM=$KALDI_ROOT/tools/irstlm/ -fi -export PATH=${PATH}:$IRSTLM/bin -if ! command -v prune-lm >/dev/null 2>&1 ; then - echo "$0: Error: the IRSTLM is not available or compiled" >&2 - echo "$0: Error: We used to install it by default, but." >&2 - echo "$0: Error: this is no longer the case." >&2 - echo "$0: Error: To install it, go to $KALDI_ROOT/tools" >&2 - echo "$0: Error: and run extras/install_irstlm.sh" >&2 - exit 1 -fi - -cd $dir -# Make directory of links to the WSJ disks such as 11-13.1. This relies on the command -# line arguments being absolute pathnames. -rm -r links/ 2>/dev/null -mkdir links/ -ln -s $* links - -# Do some basic checks that we have what we expected. -if [ ! -d links/11-13.1 -o ! -d links/13-34.1 -o ! -d links/11-2.1 ]; then - echo "wsj_data_prep.sh: Spot check of command line arguments failed" - echo "Command line arguments must be absolute pathnames to WSJ directories" - echo "with names like 11-13.1." - echo "Note: if you have old-style WSJ distribution," - echo "local/cstr_wsj_data_prep.sh may work instead, see run.sh for example." - exit 1; -fi - -# This version for SI-84 - -cat links/11-13.1/wsj0/doc/indices/train/tr_s_wv1.ndx | \ - $local/ndx2flist.pl $* | sort | \ - grep -v -i 11-2.1/wsj0/si_tr_s/401 > train_si84.flist - -nl=`cat train_si84.flist | wc -l` -[ "$nl" -eq 7138 ] || echo "Warning: expected 7138 lines in train_si84.flist, got $nl" - -# This version for SI-284 -cat links/13-34.1/wsj1/doc/indices/si_tr_s.ndx \ - links/11-13.1/wsj0/doc/indices/train/tr_s_wv1.ndx | \ - $local/ndx2flist.pl $* | sort | \ - grep -v -i 11-2.1/wsj0/si_tr_s/401 > train_si284.flist - -nl=`cat train_si284.flist | wc -l` -[ "$nl" -eq 37416 ] || echo "Warning: expected 37416 lines in train_si284.flist, got $nl" - -# Now for the test sets. -# links/13-34.1/wsj1/doc/indices/readme.doc -# describes all the different test sets. -# Note: each test-set seems to come in multiple versions depending -# on different vocabulary sizes, verbalized vs. non-verbalized -# pronunciations, etc. We use the largest vocab and non-verbalized -# pronunciations. -# The most normal one seems to be the "baseline 60k test set", which -# is h1_p0. 
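A rough Python sketch of a few of the rules documented in the (now symlinked) normalize_transcript.pl above: upcasing, stripping backslash quoting, mapping bracketed noise markers to a single noise word, and dropping the "~" truncation and "." pause markers. It is illustrative only (the noise word value is an assumed example) and is not a drop-in replacement for the Perl script:

    import re

    def normalize_transcript(trans, noise_word='<NOISE>'):  # noise_word is an assumed example value
        # trans is the transcript text without the leading utterance id
        out = []
        for w in trans.split():
            w = w.upper().replace('\\', '')   # upcase, drop backslash quoting
            if w in ('~', '.'):               # truncation / pause markers: drop
                continue
            if re.fullmatch(r'\[\w+\]', w):   # e.g. [loud_breath] -> noise word
                out.append(noise_word)
                continue
            out.append(w)
        return ' '.join(out)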
- -# Nov'92 (333 utts) -# These index files have a slightly different format; -# have to add .wv1 -cat links/11-13.1/wsj0/doc/indices/test/nvp/si_et_20.ndx | \ - $local/ndx2flist.pl $* | awk '{printf("%s.wv1\n", $1)}' | \ - sort > test_eval92.flist - -# Nov'92 (330 utts, 5k vocab) -cat links/11-13.1/wsj0/doc/indices/test/nvp/si_et_05.ndx | \ - $local/ndx2flist.pl $* | awk '{printf("%s.wv1\n", $1)}' | \ - sort > test_eval92_5k.flist - -# Nov'93: (213 utts) -# Have to replace a wrong disk-id. -cat links/13-32.1/wsj1/doc/indices/wsj1/eval/h1_p0.ndx | \ - sed s/13_32_1/13_33_1/ | \ - $local/ndx2flist.pl $* | sort > test_eval93.flist - -# Nov'93: (213 utts, 5k) -cat links/13-32.1/wsj1/doc/indices/wsj1/eval/h2_p0.ndx | \ - sed s/13_32_1/13_33_1/ | \ - $local/ndx2flist.pl $* | sort > test_eval93_5k.flist - -# Dev-set for Nov'93 (503 utts) -cat links/13-34.1/wsj1/doc/indices/h1_p0.ndx | \ - $local/ndx2flist.pl $* | sort > test_dev93.flist - -# Dev-set for Nov'93 (513 utts, 5k vocab) -cat links/13-34.1/wsj1/doc/indices/h2_p0.ndx | \ - $local/ndx2flist.pl $* | sort > test_dev93_5k.flist - - -# Dev-set Hub 1,2 (503, 913 utterances) - -# Note: the ???'s below match WSJ and SI_DT, or wsj and si_dt. -# Sometimes this gets copied from the CD's with upcasing, don't know -# why (could be older versions of the disks). -find `readlink links/13-16.1`/???1/??_??_20 -print | grep -i ".wv1" | sort > dev_dt_20.flist -find `readlink links/13-16.1`/???1/??_??_05 -print | grep -i ".wv1" | sort > dev_dt_05.flist - - -# Finding the transcript files: -for x in $*; do find -L $x -iname '*.dot'; done > dot_files.flist - -# Convert the transcripts into our format (no normalization yet) -for x in train_si84 train_si284 test_eval92 test_eval93 test_dev93 test_eval92_5k test_eval93_5k test_dev93_5k dev_dt_05 dev_dt_20; do - $local/flist2scp.pl $x.flist | sort > ${x}_sph.scp - cat ${x}_sph.scp | awk '{print $1}' | $local/find_transcripts.pl dot_files.flist > $x.trans1 -done - -# Do some basic normalization steps. At this point we don't remove OOVs-- -# that will be done inside the training scripts, as we'd like to make the -# data-preparation stage independent of the specific lexicon used. -noiseword=""; -for x in train_si84 train_si284 test_eval92 test_eval93 test_dev93 test_eval92_5k test_eval93_5k test_dev93_5k dev_dt_05 dev_dt_20; do - cat $x.trans1 | $local/normalize_transcript.pl $noiseword | sort > $x.txt || exit 1; -done - -# Create scp's with wav's. (the wv1 in the distribution is not really wav, it is sph.) -for x in train_si84 train_si284 test_eval92 test_eval93 test_dev93 test_eval92_5k test_eval93_5k test_dev93_5k dev_dt_05 dev_dt_20; do - awk '{printf("%s '$sph2pipe' -f wav %s |\n", $1, $2);}' < ${x}_sph.scp > ${x}_wav.scp -done - -# Make the utt2spk and spk2utt files. -for x in train_si84 train_si284 test_eval92 test_eval93 test_dev93 test_eval92_5k test_eval93_5k test_dev93_5k dev_dt_05 dev_dt_20; do - cat ${x}_sph.scp | awk '{print $1}' | perl -ane 'chop; m:^...:; print "$_ $&\n";' > $x.utt2spk - cat $x.utt2spk | $utils/utt2spk_to_spk2utt.pl > $x.spk2utt || exit 1; -done - - -#in case we want to limit lm's on most frequent words, copy lm training word frequency list -cp links/13-32.1/wsj1/doc/lng_modl/vocab/wfl_64.lst $lmdir -chmod u+w $lmdir/*.lst # had weird permissions on source. - -# The 20K vocab, open-vocabulary language model (i.e. the one with UNK), without -# verbalized pronunciations. This is the most common test setup, I understand. 
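For clarity, the awk line above turns each sph scp entry into a wav.scp entry whose second field is a sph2pipe command ending in "|" (a Kaldi input pipe), since the distributed .wv1 files are sphere files rather than wav. An illustrative line, reusing the example path from flist2scp.pl above and leaving the sph2pipe path unexpanded for brevity:

    4k0c030a $KALDI_ROOT/tools/sph2pipe_v2.5/sph2pipe -f wav /mnt/matylda2/data/WSJ1/13-16.1/wsj1/si_dt_20/4k0/4k0c030a.wv1 |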
- -cp links/13-32.1/wsj1/doc/lng_modl/base_lm/bcb20onp.z $lmdir/lm_bg.arpa.gz || exit 1; -chmod u+w $lmdir/lm_bg.arpa.gz - -# trigram would be: -cat links/13-32.1/wsj1/doc/lng_modl/base_lm/tcb20onp.z | \ - perl -e 'while(<>){ if(m/^\\data\\/){ print; last; } } while(<>){ print; }' | \ - gzip -c -f > $lmdir/lm_tg.arpa.gz || exit 1; - -prune-lm --threshold=1e-7 $lmdir/lm_tg.arpa.gz $lmdir/lm_tgpr.arpa || exit 1; -gzip -f $lmdir/lm_tgpr.arpa || exit 1; - -# repeat for 5k language models -cp links/13-32.1/wsj1/doc/lng_modl/base_lm/bcb05onp.z $lmdir/lm_bg_5k.arpa.gz || exit 1; -chmod u+w $lmdir/lm_bg_5k.arpa.gz - -# trigram would be: !only closed vocabulary here! -cp links/13-32.1/wsj1/doc/lng_modl/base_lm/tcb05cnp.z $lmdir/lm_tg_5k.arpa.gz || exit 1; -chmod u+w $lmdir/lm_tg_5k.arpa.gz -gunzip $lmdir/lm_tg_5k.arpa.gz -tail -n 4328839 $lmdir/lm_tg_5k.arpa | gzip -c -f > $lmdir/lm_tg_5k.arpa.gz -rm $lmdir/lm_tg_5k.arpa - -prune-lm --threshold=1e-7 $lmdir/lm_tg_5k.arpa.gz $lmdir/lm_tgpr_5k.arpa || exit 1; -gzip -f $lmdir/lm_tgpr_5k.arpa || exit 1; - - -if [ ! -f wsj0-train-spkrinfo.txt ] || [ `cat wsj0-train-spkrinfo.txt | wc -l` -ne 134 ]; then - rm wsj0-train-spkrinfo.txt - ! wget https://catalog.ldc.upenn.edu/docs/LDC93S6A/wsj0-train-spkrinfo.txt && \ - echo "Getting wsj0-train-spkrinfo.txt from backup location" && \ - wget --no-check-certificate https://sourceforge.net/projects/kaldi/files/wsj0-train-spkrinfo.txt -fi - -if [ ! -f wsj0-train-spkrinfo.txt ]; then - echo "Could not get the spkrinfo.txt file from LDC website (moved)?" - echo "This is possibly omitted from the training disks; couldn't find it." - echo "Everything else may have worked; we just may be missing gender info" - echo "which is only needed for VTLN-related diagnostics anyway." - exit 1 -fi -# Note: wsj0-train-spkrinfo.txt doesn't seem to be on the disks but the -# LDC put it on the web. Perhaps it was accidentally omitted from the -# disks. - -cat links/11-13.1/wsj0/doc/spkrinfo.txt \ - links/13-32.1/wsj1/doc/evl_spok/spkrinfo.txt \ - links/13-34.1/wsj1/doc/dev_spok/spkrinfo.txt \ - links/13-34.1/wsj1/doc/train/spkrinfo.txt \ - ./wsj0-train-spkrinfo.txt | \ - perl -ane 'tr/A-Z/a-z/; m/^;/ || print;' | \ - awk '{print $1, $2}' | grep -v -- -- | sort | uniq > spk2gender - - -echo "Data preparation succeeded" diff --git a/examples/asr_wsj/local/wsj_data_prep.sh b/examples/asr_wsj/local/wsj_data_prep.sh new file mode 120000 index 000000000..f909e21b5 --- /dev/null +++ b/examples/asr_wsj/local/wsj_data_prep.sh @@ -0,0 +1 @@ +../../../speech_tools/kaldi/egs/wsj/s5/local/wsj_data_prep.sh \ No newline at end of file diff --git a/examples/asr_wsj/local/wsj_format_data.sh b/examples/asr_wsj/local/wsj_format_data.sh deleted file mode 100755 index d567fd1bd..000000000 --- a/examples/asr_wsj/local/wsj_format_data.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/bin/bash - -# Copyright 2012 Microsoft Corporation Johns Hopkins University (Author: Daniel Povey) -# 2015 Guoguo Chen -# Apache 2.0 - -# This script takes data prepared in a corpus-dependent way -# in data/local/, and converts it into the "canonical" form, -# in various subdirectories of data/, e.g. data/lang, data/lang_test_ug, -# data/train_si284, data/train_si84, etc. - -# Don't bother doing train_si84 separately (although we have the file lists -# in data/local/) because it's just the first 7138 utterances in train_si284. -# We'll create train_si84 after doing the feature extraction. - -lang_suffix= - -echo "$0 $@" # Print the command line for logging -. 
utils/parse_options.sh || exit 1; - -. ./path.sh || exit 1; - -echo "Preparing train and test data" -srcdir=data/local/data - -for x in train_si284 test_eval92 test_eval93 test_dev93 test_eval92_5k test_eval93_5k test_dev93_5k dev_dt_05 dev_dt_20; do - mkdir -p data/$x - cp $srcdir/${x}_wav.scp data/$x/wav.scp || exit 1; - cp $srcdir/$x.txt data/$x/text || exit 1; - cp $srcdir/$x.spk2utt data/$x/spk2utt || exit 1; - cp $srcdir/$x.utt2spk data/$x/utt2spk || exit 1; - utils/filter_scp.pl data/$x/spk2utt $srcdir/spk2gender > data/$x/spk2gender || exit 1; -done - -echo "Succeeded in formatting data." diff --git a/examples/asr_wsj/local/wsj_format_data.sh b/examples/asr_wsj/local/wsj_format_data.sh new file mode 120000 index 000000000..710e8f3a0 --- /dev/null +++ b/examples/asr_wsj/local/wsj_format_data.sh @@ -0,0 +1 @@ +../../../speech_tools/kaldi/egs/wsj/s5/local/wsj_format_data.sh \ No newline at end of file diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index dfb5136ad..f58bae379 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -11,29 +11,38 @@ set -e -o pipefail stage=0 free_gpu= + +# e2e model related affix= train_set=train_si284 valid_set=test_dev93 test_set=test_eval92 checkpoint=checkpoint_best.pt validate_on_train=false +disable_cudnn=false -dumpdir=data/dump # directory to dump full features -# feature configuration -do_delta=false +# LM related +lm_affix= +lm_checkpoint=checkpoint_best.pt +lm_shallow_fusion=false -# data +# data related +dumpdir=data/dump # directory to dump full features wsj0= wsj1= if [[ $(hostname -f) == *.clsp.jhu.edu ]]; then wsj0=/export/corpora5/LDC/LDC93S6B wsj1=/export/corpora5/LDC/LDC94S13B fi +# feature configuration +do_delta=false + . ./path.sh . ./cmd.sh . ./utils/parse_options.sh +lmdir=exp/lm_lstm${lm_affix:+_${lm_affix}} dir=exp/lstm${affix:+_$affix} if [ ${stage} -le 0 ]; then @@ -77,12 +86,13 @@ fi dict=data/lang/${train_set}_units.txt nlsyms=data/lang/non_lang_syms.txt -train_text=data/$train_set/text +lmdatadir=data/lm_text if [ ${stage} -le 2 ]; then echo "Stage 2: Dictionary Preparation and Text Tokenization" mkdir -p data/lang echo "$0: making a non-linguistic symbol list..." + train_text=data/$train_set/text cut -f 2- $train_text | tr " " "\n" | sort | uniq | grep "<" > $nlsyms cat $nlsyms @@ -97,45 +107,113 @@ if [ ${stage} -le 2 ]; then wc -l $dict fi done + + echo "$0: preparing text for LM..." 
+ mkdir -p $lmdatadir + for dataset in $train_set $valid_set $test_set; do + token_text=data/$dataset/token_text + cut -f 2- -d" " $token_text > $lmdatadir/$dataset.tokens + done + zcat ${wsj1}/13-32.1/wsj1/doc/lng_modl/lm_train/np_data/{87,88,89}/*.z \ + | grep -v "<" | tr "[:lower:]" "[:upper:]" \ + | text2token.py --space "" > $lmdatadir/train_others.tokens + cat $lmdatadir/$train_set.tokens $lmdatadir/train_others.tokens > $lmdatadir/train.tokens +fi + +lmdict=$dict +if [ ${stage} -le 3 ]; then + echo "Stage 3: Text Binarization for LM Training" + lmdict=$dict + mkdir -p $lmdatadir/logs + ${decode_cmd} $lmdatadir/logs/preprocess.log \ + python3 ../../preprocess.py --task language_modeling_for_asr \ + --workers 30 --srcdict $lmdict --only-source \ + --trainpref $lmdatadir/train.tokens \ + --validpref $lmdatadir/$valid_set.tokens \ + --testpref $lmdatadir/$test_set.tokens \ + --destdir $lmdatadir +fi + +if [ ${stage} -le 4 ]; then + echo "Stage 4: LM Training" + valid_subset=valid + mkdir -p $lmdir/logs + log_file=$lmdir/logs/train.log + [ -f $lmdir/checkpoint_last.pt ] && log_file="-a $log_file" + [ -z "$free_gpu" ] && [[ $(hostname -f) == *.clsp.jhu.edu ]] && free_gpu=$(free-gpu) + [ -z "$free_gpu" ] && echo "$0: please specify --free-gpu" && exit 1; + CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../train.py $lmdatadir --seed 1 \ + --task language_modeling_for_asr --dict $lmdict \ + --log-interval 2000 --log-format simple \ + --num-workers 0 --max-tokens 25600 --max-sentences 128 \ + --valid-subset $valid_subset --max-sentences-valid 256 \ + --distributed-world-size 1 --distributed-rank 0 --distributed-port -1 \ + --max-epoch 25 --optimizer adam --lr 0.001 --weight-decay 5e-06 \ + --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \ + --save-dir $lmdir --restore-file checkpoint_last.pt --save-interval-updates 2000 \ + --keep-interval-updates 5 --keep-last-epochs 5 --validate-interval 1 \ + --arch lstm_lm_wsj --dropout 0.1 --criterion cross_entropy \ + --sample-break-mode eos 2>&1 | tee $log_file +fi + +if [ ${stage} -le 5 ]; then + echo "Stage 5: LM Evaluation" + for gen_subset in valid test; do + log_file=$lmdir/logs/evaluation_$gen_subset.log + [ -z "$free_gpu" ] && [[ $(hostname -f) == *.clsp.jhu.edu ]] && free_gpu=$(free-gpu) + [ -z "$free_gpu" ] && echo "$0: please specify --free-gpu" && exit 1; + CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../eval_lm.py $lmdatadir \ + --task language_modeling_for_asr --dict $lmdict --gen-subset $gen_subset \ + --max-tokens 192000 --max-sentences 256 --sample-break-mode eos \ + --path $lmdir/$lm_checkpoint 2>&1 | tee $log_file + done fi train_feat=$train_feat_dir/feats.scp train_token_text=data/$train_set/token_text valid_feat=$valid_feat_dir/feats.scp valid_token_text=data/$valid_set/token_text -if [ ${stage} -le 3 ]; then - echo "Stage 3: Model Training" +if [ ${stage} -le 6 ]; then + echo "Stage 6: Model Training" valid_subset=valid if $validate_on_train; then - valid_subset="$valid_subset train" + valid_subset="$valid_subset,train" fi mkdir -p $dir/logs log_file=$dir/logs/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" opts="" + $disable_cudnn && opts="$opts --disable-cudnn" [ -f local/wer_output_filter ] && \ opts="$opts --wer-output-filter local/wer_output_filter" [ -z "$free_gpu" ] && [[ $(hostname -f) == *.clsp.jhu.edu ]] && free_gpu=$(free-gpu) [ -z "$free_gpu" ] && echo "$0: please specify --free-gpu" && exit 1; CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --seed 1 \ - --log-interval 500 --log-format "simple" 
--print-training-sample-interval 1000 \ + --log-interval 500 --log-format simple --print-training-sample-interval 1000 \ --num-workers 0 --max-tokens 24000 --max-sentences 32 \ --valid-subset $valid_subset --max-sentences-valid 64 \ --distributed-world-size 1 --distributed-rank 0 --distributed-port -1 \ - --max-epoch 20 --optimizer "adam" --lr 0.001 --weight-decay 0.0 \ - --lr-scheduler "reduce_lr_on_plateau_v2" --lr-shrink 0.5 --min-lr "1e-4" --start-reduce-lr-epoch 11 \ - --save-dir $dir --restore-file "checkpoint_last.pt" --save-interval-updates 200 \ + --max-epoch 20 --optimizer adam --lr 0.001 --weight-decay 0.0 \ + --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --min-lr 1e-4 --start-reduce-lr-epoch 11 \ + --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates 200 \ --keep-interval-updates 5 --keep-last-epochs 5 --validate-interval 1 \ - --arch "speech_conv_lstm_wsj" --criterion "label_smoothed_cross_entropy_with_wer" --label-smoothing 0.05 \ + --arch speech_conv_lstm_wsj --criterion label_smoothed_cross_entropy_with_wer --label-smoothing 0.05 \ --train-feat-files $train_feat --train-text-files $train_token_text \ --valid-feat-files $valid_feat --valid-text-files $valid_token_text \ - --dict $dict --non-lang-syms $nlsyms \ + --dict $dict --non-lang-syms $nlsyms --dropout 0.2 \ --max-source-positions 9999 --max-target-positions 999 $opts 2>&1 | tee $log_file fi -if [ ${stage} -le 4 ]; then - echo "Stage 4: Decoding" +if [ ${stage} -le 7 ]; then + echo "Stage 7: Decoding" opts="" + path=$dir/$checkpoint + decode_affix= + if $lm_shallow_fusion; then + path="$path:$lmdir/$lm_checkpoint" + opts="$opts --lprob-weights 1,0.29" + decode_affix=shallow_fusion + fi [ -f local/wer_output_filter ] && \ opts="$opts --wer-output-filter local/wer_output_filter" for dataset in $valid_set $test_set; do @@ -147,13 +225,13 @@ if [ ${stage} -le 4 ]; then text=data/$dataset/token_text [ -z "$free_gpu" ] && [[ $(hostname -f) == *.clsp.jhu.edu ]] && free_gpu=$(free-gpu) [ -z "$free_gpu" ] && echo "$0: please specify --free-gpu" && exit 1; - CUDA_VISIBLE_DEVICES=$free_gpu speech_recognition.py \ - --max-tokens 45000 --max-sentences 32 --num-shards 1 --shard-id 0 \ + CUDA_VISIBLE_DEVICES=$free_gpu speech_recognize.py \ + --max-tokens 45000 --max-sentences 1 --num-shards 1 --shard-id 0 \ --test-feat-files $feat --test-text-files $text \ --dict $dict --non-lang-syms $nlsyms \ --max-source-positions 9999 --max-target-positions 999 \ - --path $dir/$checkpoint --beam 15 --max-len-a 0.5 --max-len-b 0 \ - --lenpen 1.0 --output-dir $dir/decode_$dataset --print-alignment $opts \ - 2>&1 | tee $dir/logs/decode_$dataset.log + --path $path --beam 50 --max-len-a 0.5 --max-len-b 0 --lenpen 1.1 --no-early-stop \ + --results-path $dir/decode_$dataset${decode_affix:+_${decode_affix}} $opts \ + --print-alignment 2>&1 | tee $dir/logs/decode_$dataset${decode_affix:+_${decode_affix}}.log done fi diff --git a/fairseq/data/scp_dataset.py b/fairseq/data/scp_dataset.py index bce7bf9c0..45ecde8b6 100644 --- a/fairseq/data/scp_dataset.py +++ b/fairseq/data/scp_dataset.py @@ -187,7 +187,8 @@ def read_text(self, path, dictionary): utt_id, tokens = line.strip().split(None, 1) self.utt_ids.append(utt_id) self.tokens_list.append(tokens) - tensor = dictionary.encode_line(tokens, append_eos=self.append_eos) + tensor = dictionary.encode_line(tokens, + add_if_not_exist=False, append_eos=self.append_eos).long() self.tensor_list.append(tensor) self.sizes.append(len(self.tensor_list[-1])) diff --git 
a/fairseq/data/token_dictionary.py b/fairseq/data/token_dictionary.py index f593a1225..c3c1b6dea 100644 --- a/fairseq/data/token_dictionary.py +++ b/fairseq/data/token_dictionary.py @@ -93,26 +93,6 @@ def dummy_sentence(self, length): t[-1] = self.eos() return t - def encode_line(self, line, line_tokenizer=tokenize_line, add_if_not_exist=False, - consumer=None, append_eos=True, reverse_order=False): - tokens = line_tokenizer(line) - if reverse_order: - tokens = list(reversed(tokens)) - ntokens = len(tokens) - ids = torch.LongTensor(ntokens + 1 if append_eos else ntokens) - - for i, token in enumerate(tokens): - if add_if_not_exist: - idx = self.add_symbol(token) - else: - idx = self.index(token) - ids[i] = idx - if consumer is not None: - consumer(word, idx) - if append_eos: - ids[ntokens] = self.eos_index - return ids - def tokens_to_sentence(self, line, line_tokenizer=tokenize_line, use_unk_sym=True): # use_unk_sym=False when we want to restore original transcripts from # token sequences, e.g., obtain reference to compute WER diff --git a/fairseq/models/speech_lstm.py b/fairseq/models/speech_lstm.py index 667f70fdc..8f4734ad7 100644 --- a/fairseq/models/speech_lstm.py +++ b/fairseq/models/speech_lstm.py @@ -12,8 +12,8 @@ from fairseq import options, utils from fairseq.modules import AdaptiveSoftmax, speech_attention from . import ( - FairseqEncoder, FairseqIncrementalDecoder, FairseqModel, register_model, - register_model_architecture, + FairseqEncoder, FairseqIncrementalDecoder, FairseqModel, + FairseqLanguageModel, register_model, register_model_architecture, ) from .lstm import AttentionLayer, Embedding, LSTM, LSTMCell, Linear @@ -23,8 +23,11 @@ @register_model('speech_lstm') class SpeechLSTMModel(FairseqModel): - def __init__(self, encoder, decoder): + def __init__(self, encoder, decoder, pretrained_lm=None): super().__init__(encoder, decoder) + self.pretrained_lm = pretrained_lm + if pretrained_lm is not None: + assert isinstance(self.pretrained_lm, FairseqDecoder) @staticmethod def add_args(parser): @@ -32,11 +35,11 @@ def add_args(parser): # fmt: off parser.add_argument('--dropout', type=float, metavar='D', help='dropout probability') - parser.add_argument('--encoder-conv-channels', type=str, metavar='STR', + parser.add_argument('--encoder-conv-channels', type=str, metavar='EXPR', help='list of encoder convolution\'s out channels') - parser.add_argument('--encoder-conv-kernel-sizes', type=str, metavar='STR', + parser.add_argument('--encoder-conv-kernel-sizes', type=str, metavar='EXPR', help='list of encoder convolution\'s kernel sizes') - parser.add_argument('--encoder-conv-strides', type=str, metavar='STR', + parser.add_argument('--encoder-conv-strides', type=str, metavar='EXPR', help='list of encoder convolution\'s strides') parser.add_argument('--encoder-rnn-hidden-size', type=int, metavar='N', help='encoder rnn\'s hidden size') @@ -68,6 +71,9 @@ def add_args(parser): 'Must be used with adaptive_loss criterion') parser.add_argument('--share-decoder-input-output-embed', action='store_true', help='share decoder input and output embeddings') + parser.add_argument('--pretrained-lm-checkpoint', type=str, metavar='STR', + help='path to load checkpoint from pretrained language model(LM), ' + 'which will be present and kept fixed during training.') # Granular dropout settings (if not specified these default to --dropout) parser.add_argument('--encoder-rnn-dropout-in', type=float, metavar='D', @@ -106,7 +112,7 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, 
embed_dim): if args.share_decoder_input_output_embed and ( args.decoder_embed_dim != args.decoder_out_embed_dim): raise ValueError( - '--share-decoder-input-output-embeddings requires ' + '--share-decoder-input-output-embed requires ' '--decoder-embed-dim to match --decoder-out-embed-dim' ) @@ -182,7 +188,115 @@ def eval_str_nested_list_or_tuple(x, type=int): if args.criterion == 'adaptive_loss' else None ), ) - return cls(encoder, decoder) + pretrained_lm = None + if args.pretrained_lm_checkpoint: + print('| loading pretrained LM from {}'.format(args.pretrained_lm_checkpoint)) + pretrained_lm = utils.load_ensemble_for_inference( + args.pretrained_lm_checkpoint, task)[0][0] + pretrained_lm.make_generation_fast_() + # freeze pretrained model + for param in pretrained_lm.parameters(): + param.requires_grad = False + return cls(encoder, decoder, pretrained_lm) + + def max_positions(self): + """Maximum length supported by the model.""" + return (self.encoder.max_positions(), + self.decoder.max_positions() if self.pretrained_lm is None else \ + min(self.decoder.max_positions(), self.pretrained_lm.max_positions()) + ) + + def max_decoder_positions(self): + """Maximum length supported by the decoder.""" + return self.decoder.max_positions() if self.pretrained_lm is None else \ + min(self.decoder.max_positions(), self.pretrained_lm.max_positions()) + + +@register_model('lstm_lm') +class LSTMLanguageModel(FairseqLanguageModel): + def __init__(self, decoder): + super().__init__(decoder) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--dropout', type=float, metavar='D', + help='dropout probability') + parser.add_argument('--decoder-embed-dim', type=int, metavar='N', + help='decoder embedding dimension') + parser.add_argument('--decoder-embed-path', type=str, metavar='STR', + help='path to pre-trained decoder embedding') + parser.add_argument('--decoder-freeze-embed', action='store_true', + help='freeze decoder embeddings') + parser.add_argument('--decoder-hidden-size', type=int, metavar='N', + help='decoder hidden size') + parser.add_argument('--decoder-layers', type=int, metavar='N', + help='number of decoder layers') + parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N', + help='decoder output embedding dimension') + parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', + help='comma separated list of adaptive softmax cutoff points. 
' + 'Must be used with adaptive_loss criterion') + parser.add_argument('--share-embed', action='store_true', + help='share input and output embeddings') + + # Granular dropout settings (if not specified these default to --dropout) + parser.add_argument('--decoder-dropout-in', type=float, metavar='D', + help='dropout probability for decoder input embedding') + parser.add_argument('--decoder-dropout-out', type=float, metavar='D', + help='dropout probability for decoder output') + # fmt: on + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + # make sure all arguments are present in older models + base_lm_architecture(args) + + def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx) + embed_dict = utils.parse_embedding(embed_path) + utils.print_embed_overlap(embed_dict, dictionary) + return utils.load_embedding(embed_dict, dictionary, embed_tokens) + + # separate decoder input embeddings + pretrained_decoder_embed = None + if args.decoder_embed_path: + pretrained_decoder_embed = load_pretrained_embedding_from_file( + args.decoder_embed_path, + task.target_dictionary, + args.decoder_embed_dim + ) + # one last double check of parameter combinations + if args.share_embed and ( + args.decoder_embed_dim != args.decoder_out_embed_dim): + raise ValueError( + '--share-embed requires ' + '--decoder-embed-dim to match --decoder-out-embed-dim' + ) + + if args.decoder_freeze_embed: + pretrained_decoder_embed.weight.requires_grad = False + + decoder = SpeechLSTMDecoder( + dictionary=task.target_dictionary, + embed_dim=args.decoder_embed_dim, + hidden_size=args.decoder_hidden_size, + out_embed_dim=args.decoder_out_embed_dim, + num_layers=args.decoder_layers, + dropout_in=args.decoder_dropout_in, + dropout_out=args.decoder_dropout_out, + pretrained_embed=pretrained_decoder_embed, + share_input_output_embed=args.share_embed, + adaptive_softmax_cutoff=( + options.eval_str_list(args.adaptive_softmax_cutoff, type=int) + if args.criterion == 'adaptive_loss' else None + ), + ) + return LSTMLanguageModel(decoder) class ConvBNReLU(nn.Module): @@ -257,7 +371,7 @@ def __init__( input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, - dropout=self.dropout_out if num_layers > 1 else 0., + #dropout=self.dropout_out if num_layers > 1 else 0., bidirectional=bidirectional, ) self.left_pad = left_pad @@ -346,16 +460,19 @@ class SpeechLSTMDecoder(FairseqIncrementalDecoder): """LSTM decoder.""" def __init__( self, dictionary, embed_dim=512, hidden_size=512, out_embed_dim=512, - num_layers=1, dropout_in=0.1, dropout_out=0.1, - encoder_output_units=512, attn_type='bahdanau', attn_dim=256, - need_attn=False, pretrained_embed=None, share_input_output_embed=False, - adaptive_softmax_cutoff=None, + num_layers=1, dropout_in=0.1, dropout_out=0.1, encoder_output_units=0, + attn_type=None, attn_dim=0, need_attn=False, pretrained_embed=None, + share_input_output_embed=False, adaptive_softmax_cutoff=None, ): super().__init__(dictionary) self.dropout_in = dropout_in self.dropout_out = dropout_out self.hidden_size = hidden_size self.share_input_output_embed = share_input_output_embed + if attn_type is None or attn_type.lower() == 'none': + # no attention, no encoder output needed (language model case) + need_attn = False + encoder_output_units = 0 self.need_attn = need_attn self.adaptive_softmax = None @@ -375,10 +492,12 @@ def 
__init__( ) for layer in range(num_layers) ]) - if attn_type == 'bahdanau': + if attn_type is None or attn_type.lower() == 'none': + self.attention = None + elif attn_type.lower() == 'bahdanau': self.attention = speech_attention.BahdanauAttention(hidden_size, encoder_output_units, attn_dim) - elif attn_type == 'luong': + elif attn_type.lower() == 'luong': self.attention = speech_attention.LuongAttention(hidden_size, encoder_output_units) else: @@ -392,18 +511,19 @@ def __init__( elif not self.share_input_output_embed: self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out) - def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None): - encoder_out = encoder_out_dict['encoder_out'] - encoder_padding_mask = encoder_out_dict['encoder_padding_mask'] + def forward(self, prev_output_tokens, encoder_out_dict=None, incremental_state=None): + if self.attention is not None: + assert encoder_out_dict is not None + encoder_out = encoder_out_dict['encoder_out'] + encoder_padding_mask = encoder_out_dict['encoder_padding_mask'] + # get outputs from encoder + encoder_outs = encoder_out[0] + srclen = encoder_outs.size(0) if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] bsz, seqlen = prev_output_tokens.size() - # get outputs from encoder - encoder_outs = encoder_out[0] - srclen = encoder_outs.size(0) - # embed tokens x = self.embed_tokens(prev_output_tokens) x = F.dropout(x, p=self.dropout_in, training=self.training) @@ -421,26 +541,32 @@ def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None): for i in range(num_layers)] prev_cells = [x.new_zeros(bsz, self.hidden_size) \ for i in range(num_layers)] - input_feed = x.new_zeros(bsz, self.encoder_output_units) + input_feed = x.new_zeros(bsz, self.encoder_output_units) \ + if self.attention is not None else None - attn_scores = x.new_zeros(srclen, seqlen, bsz) + if self.attention is not None: + attn_scores = x.new_zeros(srclen, seqlen, bsz) outs = [] for j in range(seqlen): # input feeding: concatenate context vector from previous time step - input = torch.cat((x[j, :, :], input_feed), dim=1) + input = torch.cat((x[j, :, :], input_feed), dim=1) \ + if input_feed is not None else x[j, :, :] for i, rnn in enumerate(self.layers): # recurrent cell hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i])) # compute and apply attention using the 1st layer's hidden state - if i == 0: - context, attn_scores[:, j, :], _ = self.attention(hidden, - encoder_outs, encoder_padding_mask) - - # hidden state concatenated with context vector becomes the - # input to the next layer - input = torch.cat((hidden, context), dim=1) + if self.attention is not None: + if i == 0: + context, attn_scores[:, j, :], _ = self.attention(hidden, + encoder_outs, encoder_padding_mask) + + # hidden state concatenated with context vector becomes the + # input to the next layer + input = torch.cat((hidden, context), dim=1) + else: + input = hidden input = F.dropout(input, p=self.dropout_out, training=self.training) # save state for next time step @@ -448,7 +574,7 @@ def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None): prev_cells[i] = cell # input feeding - input_feed = context + input_feed = context if self.attention is not None else None # save final output outs.append(input) @@ -467,7 +593,7 @@ def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None): x = x.transpose(1, 0) # srclen x tgtlen x bsz -> bsz x tgtlen x srclen - if not self.training and 
self.need_attn: + if not self.training and self.attention is not None and self.need_attn: attn_scores = attn_scores.transpose(0, 2) else: attn_scores = None @@ -492,7 +618,7 @@ def reorder_incremental_state(self, incremental_state, new_order): def reorder_state(state): if isinstance(state, list): return [reorder_state(state_i) for state_i in state] - return state.index_select(0, new_order) + return state.index_select(0, new_order) if state is not None else None new_state = tuple(map(reorder_state, cached_state)) utils.set_incremental_state(self, incremental_state, 'cached_state', new_state) @@ -527,6 +653,26 @@ def Convolution2d(in_channels, out_channels, kernel_size, stride): return m +@register_model_architecture('lstm_lm', 'lstm_lm') +def base_lm_architecture(args): + args.dropout = getattr(args, 'dropout', 0.2) + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 48) + args.decoder_embed_path = getattr(args, 'decoder_embed_path', None) + args.decoder_freeze_embed = getattr(args, 'decoder_freeze_embed', False) + args.decoder_hidden_size = getattr(args, 'decoder_hiden_size', 650) + args.decoder_layers = getattr(args, 'decoder_layers', 2) + args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 650) + args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None) + args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', args.dropout) + args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout) + args.share_embed = getattr(args, 'share_embed', False) + + +@register_model_architecture('lstm_lm', 'lstm_lm_wsj') +def lstm_lm_wsj(args): + base_lm_architecture(args) + + @register_model_architecture('speech_lstm', 'speech_lstm') def base_architecture(args): args.dropout = getattr(args, 'dropout', 0.1) @@ -539,10 +685,10 @@ def base_architecture(args): args.encoder_rnn_hidden_size = getattr(args, 'encoder_rnn_hidden_size', 320) args.encoder_rnn_layers = getattr(args, 'encoder_rnn_layers', 3) args.encoder_rnn_bidirectional = getattr(args, 'encoder_rnn_bidirectional', True) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 320) + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 48) args.decoder_embed_path = getattr(args, 'decoder_embed_path', None) args.decoder_freeze_embed = getattr(args, 'decoder_freeze_embed', False) - args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', args.decoder_embed_dim) + args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 320) args.decoder_layers = getattr(args, 'decoder_layers', 3) args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 960) args.attention_type = getattr(args, 'attention_type', 'bahdanau') @@ -554,17 +700,11 @@ def base_architecture(args): args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout) args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,50000,200000') args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False) + args.pretrained_lm_checkpoint = getattr(args, 'pretrained_lm_checkpoint', None) @register_model_architecture('speech_lstm', 'speech_conv_lstm_wsj') def conv_lstm_wsj(args): - ''' - args.dropout = getattr(args, 'dropout', 0.1) - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256) - args.encoder_dropout_in = getattr(args, 'encoder_dropout_in', 0) - args.encoder_dropout_out = getattr(args, 'encoder_dropout_out', 0) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256) - args.decoder_out_embed_dim = 
getattr(args, 'decoder_out_embed_dim', 256) - args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', 0) - args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout) - ''' + args.encoder_rnn_hidden_size = getattr(args, 'encoder_rnn_hidden_size', 512) + args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 512) + args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 1536) base_architecture(args) diff --git a/fairseq/modules/speech_attention.py b/fairseq/modules/speech_attention.py index f04a45d18..d8d4e5524 100644 --- a/fairseq/modules/speech_attention.py +++ b/fairseq/modules/speech_attention.py @@ -9,7 +9,6 @@ import torch from torch import nn from torch.nn import Parameter -import torch.nn.functional as F from fairseq import utils @@ -23,6 +22,11 @@ def __init__(self, query_dim, value_dim, embed_dim=None): self.value_dim = value_dim self.embed_dim = embed_dim + self.onnx_trace = False + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + def reset_parameters(self): pass @@ -73,7 +77,8 @@ def forward(self, query, value, key_padding_mask=None, state=None): key_padding_mask, float('-inf'), ).type_as(attn_scores) # FP16 support: cast to float and back - attn_scores = F.softmax(attn_scores, dim=0) # len x bsz + attn_scores = utils.softmax(attn_scores, dim=0, + onnx_trace=self.onnx_trace).type_as(attn_scores) # len x bsz # sum weighted value. context: bsz x value_dim context = (attn_scores.unsqueeze(2) * value).sum(dim=0) @@ -112,7 +117,8 @@ def forward(self, query, value, key_padding_mask=None, state=None): key_padding_mask, float('-inf'), ).type_as(attn_scores) # FP16 support: cast to float and back - attn_scores = F.softmax(attn_scores, dim=0) # len x bsz + attn_scores = utils.softmax(attn_scores, dim=0, + onnx_trace=self.onnx_trace).type_as(attn_scores) # len x bsz # sum weighted value. context: bsz x value_dim context = (attn_scores.unsqueeze(2) * value).sum(dim=0) diff --git a/fairseq/tasks/language_modeling_for_asr.py b/fairseq/tasks/language_modeling_for_asr.py new file mode 100644 index 000000000..0745e21ef --- /dev/null +++ b/fairseq/tasks/language_modeling_for_asr.py @@ -0,0 +1,125 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import torch + +import os + +from fairseq import tokenizer +from fairseq.data import TokenDictionary + +from .language_modeling import LanguageModelingTask + +from . import register_task + + +@register_task('language_modeling_for_asr') +class LanguageModelingForASRTask(LanguageModelingTask): + """ + Train a language model. + + Args: + dictionary (~fairseq.data.TokenDictionary): the dictionary for the input of + the language model + output_dictionary (~fairseq.data.TokenDictionary): the dictionary for the + output of the language model. In most cases it will be the same as + *dictionary*, but could possibly be a more limited version of the + dictionary (if ``--output-dictionary-size`` is used). + targets (List[str]): list of the target types that the language model + should predict. Can be one of "self", "future", and "past". + Defaults to "future". + + .. note:: + + The language modeling task is compatible with :mod:`fairseq-train`, + :mod:`fairseq-generate`, :mod:`fairseq-interactive` and + :mod:`fairseq-eval-lm`. 
+ + The language modeling task provides the following additional command-line + arguments: + + .. argparse:: + :ref: fairseq.tasks.language_modeling_for_asr_parser + :prog: + """ + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + # fmt: off + LanguageModelingTask.add_args(parser) + parser.add_argument('--dict', default=None, type=str, + help='path to the dictionary') + # fmt: on + + def __init__(self, args, dictionary, output_dictionary, targets=None): + super().__init__(args, dictionary, output_dictionary, targets=targets) + torch.backends.cudnn.deterministic = True + + @classmethod + def load_dictionary(cls, filename): + """Load the dictionary from the filename + + Args: + filename (str): the filename + """ + return TokenDictionary.load(filename) + + @classmethod + def build_dictionary(cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8): + """Build the dictionary + + Args: + filenames (list): list of filenames + workers (int): number of concurrent workers + threshold (int): defines the minimum word count + nwords (int): defines the total number of words in the final dictionary, + including special symbols + padding_factor (int): can be used to pad the dictionary size to be a + multiple of 8, which is important on some hardware (e.g., Nvidia + Tensor Cores). + """ + d = TokenDictionary() + for filename in filenames: + TokenDictionary.add_file_to_dictionary(filename, d, tokenizer.tokenize_line, workers) + d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor) + return d + + @classmethod + def setup_task(cls, args, **kwargs): + """Setup the task (e.g., load dictionaries). + + Args: + args (argparse.Namespace): parsed command-line arguments + """ + dictionary = None + output_dictionary = None + if args.data: + dict_path = os.path.join(args.data, 'dict.txt') if args.dict is None \ + else args.dict + dictionary = TokenDictionary.load(dict_path) + print('| dictionary: {} types'.format(len(dictionary))) + output_dictionary = dictionary + if args.output_dictionary_size >= 0: + output_dictionary = TruncatedDictionary(dictionary, args.output_dictionary_size) + + # upgrade old checkpoints + if hasattr(args, 'exclude_self_target'): + args.self_target = not args.exclude_self_target + + targets = [] + if getattr(args, 'self_target', False): + targets.append('self') + if getattr(args, 'future_target', False): + targets.append('future') + if getattr(args, 'past_target', False): + targets.append('past') + if len(targets) == 0: + # standard language modeling + targets = ['future'] + + return cls(args, dictionary, output_dictionary, targets=targets) diff --git a/fairseq/tasks/speech_recognition.py b/fairseq/tasks/speech_recognition.py index 757131dff..f21e3257f 100644 --- a/fairseq/tasks/speech_recognition.py +++ b/fairseq/tasks/speech_recognition.py @@ -5,6 +5,8 @@ # the root directory of this source tree. An additional grant of patent rights # can be found in the PATENTS file in the same directory. 
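As a quick standalone check of the new task's --dict handling, a minimal sketch, assuming the units file produced by stage 2 of examples/asr_wsj/run.sh is present at the path below:

    from fairseq.data import TokenDictionary

    # same call the task's load_dictionary()/setup_task() make internally
    d = TokenDictionary.load('data/lang/train_si284_units.txt')
    print('| dictionary: {} types'.format(len(d)))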
+import torch + import itertools import os import re @@ -117,6 +119,7 @@ def __init__(self, args, dict): super().__init__(args) self.dict = dict self.feat_in_channels = args.feat_in_channels + torch.backends.cudnn.deterministic = True @classmethod def setup_task(cls, args, **kwargs): @@ -210,6 +213,12 @@ def build_generator(self, args): def build_dataset_for_inference(self, src_tokens, src_lengths): return SpeechDataset(src_tokens, src_lengths) + def inference_step(self, generator, models, sample, prefix_tokens=None, + lprob_weights=None): + with torch.no_grad(): + return generator.generate(models, sample, prefix_tokens=prefix_tokens, + lprob_weights=lprob_weights) + def max_positions(self): """Return the max sentence length allowed by the task.""" return (self.args.max_source_positions, self.args.max_target_positions) diff --git a/speech_recognition.py b/speech_recognize.py similarity index 85% rename from speech_recognition.py rename to speech_recognize.py index d271242a6..4ac27b07b 100755 --- a/speech_recognition.py +++ b/speech_recognize.py @@ -47,6 +47,8 @@ def main(args): models, _model_args = utils.load_ensemble_for_inference( args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides), ) + if args.lprob_weights is not None: + print('| using model ensemble with lprob-weights={}'.format(str(args.lprob_weights))) # Optimize ensemble for generation for model in models: @@ -66,7 +68,8 @@ def main(args): max_sentences=args.max_sentences, max_positions=utils.resolve_max_positions( task.max_positions(), - *[model.max_positions() for model in models] + *[model.max_positions() if hasattr(model, 'encoder') \ + else (None, model.max_positions()) for model in models] ), ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, required_batch_size_multiple=8, @@ -98,7 +101,8 @@ def main(args): prefix_tokens = sample['target'][:, :args.prefix_size] gen_timer.start() - hypos = task.inference_step(generator, models, sample, prefix_tokens) + hypos = task.inference_step(generator, models, sample, prefix_tokens, + lprob_weights=args.lprob_weights) num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos) gen_timer.stop(num_generated_tokens) @@ -129,7 +133,7 @@ def main(args): attention = hypo['attention'].float().cpu() \ if hypo['attention'] is not None else None if attention is not None: - save_dir = os.path.join(args.output_dir, 'attn_plots') + save_dir = os.path.join(args.results_path, 'attn_plots') os.makedirs(save_dir, exist_ok=True) plot_attention(attention, hypo_sent, utt_id, save_dir) scorer.add_prediction(utt_id, hypo_str) @@ -147,17 +151,17 @@ def main(args): scorer.add_ordered_utt_list(*args.test_text_files) - os.makedirs(args.output_dir, exist_ok=True) + os.makedirs(args.results_path, exist_ok=True) fn = 'decoded_results.txt' - with open(os.path.join(args.output_dir, fn), 'w', encoding='utf-8') as f: + with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f: f.write(scorer.print_results()) print('| Decoded results saved as ' + f.name) if has_target: header = ' Recognize {} with beam={}: '.format(args.gen_subset, args.beam) fn = 'wer' - with open(os.path.join(args.output_dir, fn), 'w', encoding='utf-8') as f: + with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f: res = 'WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%'.format( *(scorer.wer())) print('|' + header + res) @@ -165,7 +169,7 @@ def main(args): print('| WER saved in ' + f.name) fn = 'cer' - with open(os.path.join(args.output_dir, fn), 'w', encoding='utf-8') as 
f: + with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f: res = 'CER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%'.format( *(scorer.cer())) print('|' + ' ' * len(header) + res) @@ -173,9 +177,10 @@ def main(args): print('| CER saved in ' + f.name) fn = 'aligned_results.txt' - with open(os.path.join(args.output_dir, fn), 'w', encoding='utf-8') as f: + with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f: f.write(scorer.print_aligned_results()) print('| Aligned results saved as ' + f.name) + return scorer def print_options_meaning_changes(args): @@ -189,9 +194,12 @@ def print_options_meaning_changes(args): def cli_main(): parser = options.get_generation_parser(default_task='speech_recognition') - parser.add_argument('--output-dir', metavar='DIR', required=True, - help='path to output results') + parser.add_argument('--lprob-weights', default=None, type=options.eval_str_list, + metavar='W_1,W_2,...,W_N', + help='model ensemble weights in log-prob space, the same' + 'length as number of models specified in --path') args = options.parse_args_and_arch(parser) + assert args.results_path is not None, 'please specify --results-path' print_options_meaning_changes(args) main(args) diff --git a/speech_tools/.gitignore b/speech_tools/.gitignore index 6fc07e5f7..77acdc259 100644 --- a/speech_tools/.gitignore +++ b/speech_tools/.gitignore @@ -1 +1,3 @@ kaldi +kaldi-io-for-python +kaldi_io.py diff --git a/speech_tools/Makefile b/speech_tools/Makefile index bc98fca9e..a9f4035b0 100644 --- a/speech_tools/Makefile +++ b/speech_tools/Makefile @@ -2,7 +2,11 @@ KALDI = .PHONY: all clean -all: kaldi +all: kaldi kaldi-io-for-python + +kaldi-io-for-python: + git clone https://github.com/vesis84/kaldi-io-for-python.git + ln -nfs kaldi-io-for-python/kaldi_io/kaldi_io.py kaldi_io.py ifneq ($(strip $(KALDI)),) kaldi: @@ -15,4 +19,4 @@ kaldi: endif clean: - rm -fr kaldi + rm -fr kaldi kaldi-io-for-python kaldi_io.py diff --git a/speech_tools/kaldi_io.py b/speech_tools/kaldi_io.py deleted file mode 100644 index 7cc7a1dec..000000000 --- a/speech_tools/kaldi_io.py +++ /dev/null @@ -1,630 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -# Copyright 2014-2016 Brno University of Technology (author: Karel Vesely) -# Licensed under the Apache License, Version 2.0 (the "License") - -import numpy as np -import sys, os, re, gzip, struct - -################################################# -# Adding kaldi tools to shell path, - -# Select kaldi, -if not 'KALDI_ROOT' in os.environ: - # Default! 
To change run python with 'export KALDI_ROOT=/some_dir python' - os.environ['KALDI_ROOT']='/mnt/matylda5/iveselyk/Tools/kaldi-trunk' - -# Add kaldi tools to path, -os.environ['PATH'] = os.popen('echo $KALDI_ROOT/src/bin:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/src/fstbin/:$KALDI_ROOT/src/gmmbin/:$KALDI_ROOT/src/featbin/:$KALDI_ROOT/src/lm/:$KALDI_ROOT/src/sgmmbin/:$KALDI_ROOT/src/sgmm2bin/:$KALDI_ROOT/src/fgmmbin/:$KALDI_ROOT/src/latbin/:$KALDI_ROOT/src/nnetbin:$KALDI_ROOT/src/nnet2bin:$KALDI_ROOT/src/nnet3bin:$KALDI_ROOT/src/online2bin/:$KALDI_ROOT/src/ivectorbin/:$KALDI_ROOT/src/lmbin/').readline().strip() + ':' + os.environ['PATH'] - - -################################################# -# Define all custom exceptions, -class UnsupportedDataType(Exception): pass -class UnknownVectorHeader(Exception): pass -class UnknownMatrixHeader(Exception): pass - -class BadSampleSize(Exception): pass -class BadInputFormat(Exception): pass - -class SubprocessFailed(Exception): pass - -################################################# -# Data-type independent helper functions, - -def open_or_fd(file, mode='rb'): - """ fd = open_or_fd(file) - Open file, gzipped file, pipe, or forward the file-descriptor. - Eventually seeks in the 'file' argument contains ':offset' suffix. - """ - offset = None - try: - # strip 'ark:' prefix from r{x,w}filename (optional), - if re.search('^(ark|scp)(,scp|,b|,t|,n?f|,n?p|,b?o|,n?s|,n?cs)*:', file): - (prefix,file) = file.split(':',1) - # separate offset from filename (optional), - if re.search(':[0-9]+$', file): - (file,offset) = file.rsplit(':',1) - # input pipe? - if file[-1] == '|': - fd = popen(file[:-1], 'rb') # custom, - # output pipe? - elif file[0] == '|': - fd = popen(file[1:], 'wb') # custom, - # is it gzipped? - elif file.split('.')[-1] == 'gz': - fd = gzip.open(file, mode) - # a normal file... - else: - fd = open(file, mode) - except TypeError: - # 'file' is opened file descriptor, - fd = file - # Eventually seek to offset, - if offset != None: fd.seek(int(offset)) - return fd - -# based on '/usr/local/lib/python3.4/os.py' -def popen(cmd, mode="rb"): - if not isinstance(cmd, str): - raise TypeError("invalid cmd type (%s, expected string)" % type(cmd)) - - import subprocess, io, threading - - # cleanup function for subprocesses, - def cleanup(proc, cmd): - ret = proc.wait() - if ret > 0: - raise SubprocessFailed('cmd %s returned %d !' % (cmd,ret)) - return - - # text-mode, - if mode == "r": - proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE) - threading.Thread(target=cleanup,args=(proc,cmd)).start() # clean-up thread, - return io.TextIOWrapper(proc.stdout) - elif mode == "w": - proc = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE) - threading.Thread(target=cleanup,args=(proc,cmd)).start() # clean-up thread, - return io.TextIOWrapper(proc.stdin) - # binary, - elif mode == "rb": - proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE) - threading.Thread(target=cleanup,args=(proc,cmd)).start() # clean-up thread, - return proc.stdout - elif mode == "wb": - proc = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE) - threading.Thread(target=cleanup,args=(proc,cmd)).start() # clean-up thread, - return proc.stdin - # sanity, - else: - raise ValueError("invalid mode %s" % mode) - - -def read_key(fd): - """ [key] = read_key(fd) - Read the utterance-key from the opened ark/stream descriptor 'fd'. 
- """ - key = '' - while 1: - char = fd.read(1).decode("latin1") - if char == '' : break - if char == ' ' : break - key += char - key = key.strip() - if key == '': return None # end of file, - assert(re.match('^\S+$',key) != None) # check format (no whitespace!) - return key - - -################################################# -# Integer vectors (alignments, ...), - -def read_ali_ark(file_or_fd): - """ Alias to 'read_vec_int_ark()' """ - return read_vec_int_ark(file_or_fd) - -def read_vec_int_ark(file_or_fd): - """ generator(key,vec) = read_vec_int_ark(file_or_fd) - Create generator of (key,vector) tuples, which reads from the ark file/stream. - file_or_fd : ark, gzipped ark, pipe or opened file descriptor. - - Read ark to a 'dictionary': - d = { u:d for u,d in kaldi_io.read_vec_int_ark(file) } - """ - fd = open_or_fd(file_or_fd) - try: - key = read_key(fd) - while key: - ali = read_vec_int(fd) - yield key, ali - key = read_key(fd) - finally: - if fd is not file_or_fd: fd.close() - -def read_vec_int(file_or_fd): - """ [int-vec] = read_vec_int(file_or_fd) - Read kaldi integer vector, ascii or binary input, - """ - fd = open_or_fd(file_or_fd) - binary = fd.read(2).decode() - if binary == '\0B': # binary flag - assert(fd.read(1).decode() == '\4'); # int-size - vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # vector dim - # Elements from int32 vector are sored in tuples: (sizeof(int32), value), - vec = np.frombuffer(fd.read(vec_size*5), dtype=[('size','int8'),('value','int32')], count=vec_size) - assert(vec[0]['size'] == 4) # int32 size, - ans = vec[:]['value'] # values are in 2nd column, - else: # ascii, - arr = (binary + fd.readline().decode()).strip().split() - try: - arr.remove('['); arr.remove(']') # optionally - except ValueError: - pass - ans = np.array(arr, dtype=int) - if fd is not file_or_fd : fd.close() # cleanup - return ans - -# Writing, -def write_vec_int(file_or_fd, v, key=''): - """ write_vec_int(f, v, key='') - Write a binary kaldi integer vector to filename or stream. - Arguments: - file_or_fd : filename or opened file descriptor for writing, - v : the vector to be stored, - key (optional) : used for writing ark-file, the utterance-id gets written before the vector. - - Example of writing single vector: - kaldi_io.write_vec_int(filename, vec) - - Example of writing arkfile: - with open(ark_file,'w') as f: - for key,vec in dict.iteritems(): - kaldi_io.write_vec_flt(f, vec, key=key) - """ - fd = open_or_fd(file_or_fd, mode='wb') - if sys.version_info[0] == 3: assert(fd.mode == 'wb') - try: - if key != '' : fd.write((key+' ').encode("latin1")) # ark-files have keys (utterance-id), - fd.write('\0B'.encode()) # we write binary! - # dim, - fd.write('\4'.encode()) # int32 type, - fd.write(struct.pack(np.dtype('int32').char, v.shape[0])) - # data, - for i in range(len(v)): - fd.write('\4'.encode()) # int32 type, - fd.write(struct.pack(np.dtype('int32').char, v[i])) # binary, - finally: - if fd is not file_or_fd : fd.close() - - -################################################# -# Float vectors (confidences, ivectors, ...), - -# Reading, -def read_vec_flt_scp(file_or_fd): - """ generator(key,mat) = read_vec_flt_scp(file_or_fd) - Returns generator of (key,vector) tuples, read according to kaldi scp. - file_or_fd : scp, gzipped scp, pipe or opened file descriptor. - - Iterate the scp: - for key,vec in kaldi_io.read_vec_flt_scp(file): - ... 
- - Read scp to a 'dictionary': - d = { key:mat for key,mat in kaldi_io.read_mat_scp(file) } - """ - fd = open_or_fd(file_or_fd) - try: - for line in fd: - (key,rxfile) = line.decode().split(' ') - vec = read_vec_flt(rxfile) - yield key, vec - finally: - if fd is not file_or_fd : fd.close() - -def read_vec_flt_ark(file_or_fd): - """ generator(key,vec) = read_vec_flt_ark(file_or_fd) - Create generator of (key,vector) tuples, reading from an ark file/stream. - file_or_fd : ark, gzipped ark, pipe or opened file descriptor. - - Read ark to a 'dictionary': - d = { u:d for u,d in kaldi_io.read_vec_flt_ark(file) } - """ - fd = open_or_fd(file_or_fd) - try: - key = read_key(fd) - while key: - ali = read_vec_flt(fd) - yield key, ali - key = read_key(fd) - finally: - if fd is not file_or_fd: fd.close() - -def read_vec_flt(file_or_fd): - """ [flt-vec] = read_vec_flt(file_or_fd) - Read kaldi float vector, ascii or binary input, - """ - fd = open_or_fd(file_or_fd) - binary = fd.read(2).decode() - if binary == '\0B': # binary flag - return _read_vec_flt_binary(fd) - else: # ascii, - arr = (binary + fd.readline().decode()).strip().split() - try: - arr.remove('['); arr.remove(']') # optionally - except ValueError: - pass - ans = np.array(arr, dtype=float) - if fd is not file_or_fd : fd.close() # cleanup - return ans - -def _read_vec_flt_binary(fd): - header = fd.read(3).decode() - if header == 'FV ' : sample_size = 4 # floats - elif header == 'DV ' : sample_size = 8 # doubles - else : raise UnknownVectorHeader("The header contained '%s'" % header) - assert (sample_size > 0) - # Dimension, - assert (fd.read(1).decode() == '\4'); # int-size - vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # vector dim - # Read whole vector, - buf = fd.read(vec_size * sample_size) - if sample_size == 4 : ans = np.frombuffer(buf, dtype='float32') - elif sample_size == 8 : ans = np.frombuffer(buf, dtype='float64') - else : raise BadSampleSize - return ans - - -# Writing, -def write_vec_flt(file_or_fd, v, key=''): - """ write_vec_flt(f, v, key='') - Write a binary kaldi vector to filename or stream. Supports 32bit and 64bit floats. - Arguments: - file_or_fd : filename or opened file descriptor for writing, - v : the vector to be stored, - key (optional) : used for writing ark-file, the utterance-id gets written before the vector. - - Example of writing single vector: - kaldi_io.write_vec_flt(filename, vec) - - Example of writing arkfile: - with open(ark_file,'w') as f: - for key,vec in dict.iteritems(): - kaldi_io.write_vec_flt(f, vec, key=key) - """ - fd = open_or_fd(file_or_fd, mode='wb') - if sys.version_info[0] == 3: assert(fd.mode == 'wb') - try: - if key != '' : fd.write((key+' ').encode("latin1")) # ark-files have keys (utterance-id), - fd.write('\0B'.encode()) # we write binary! - # Data-type, - if v.dtype == 'float32': fd.write('FV '.encode()) - elif v.dtype == 'float64': fd.write('DV '.encode()) - else: raise UnsupportedDataType("'%s', please use 'float32' or 'float64'" % v.dtype) - # Dim, - fd.write('\04'.encode()) - fd.write(struct.pack(np.dtype('uint32').char, v.shape[0])) # dim - # Data, - fd.write(v.tobytes()) - finally: - if fd is not file_or_fd : fd.close() - - -################################################# -# Float matrices (features, transformations, ...), - -# Reading, -def read_mat_scp(file_or_fd): - """ generator(key,mat) = read_mat_scp(file_or_fd) - Returns generator of (key,matrix) tuples, read according to kaldi scp. - file_or_fd : scp, gzipped scp, pipe or opened file descriptor. 
- - Iterate the scp: - for key,mat in kaldi_io.read_mat_scp(file): - ... - - Read scp to a 'dictionary': - d = { key:mat for key,mat in kaldi_io.read_mat_scp(file) } - """ - fd = open_or_fd(file_or_fd) - try: - for line in fd: - (key,rxfile) = line.decode().split(' ') - mat = read_mat(rxfile) - yield key, mat - finally: - if fd is not file_or_fd : fd.close() - -def read_mat_ark(file_or_fd): - """ generator(key,mat) = read_mat_ark(file_or_fd) - Returns generator of (key,matrix) tuples, read from ark file/stream. - file_or_fd : scp, gzipped scp, pipe or opened file descriptor. - - Iterate the ark: - for key,mat in kaldi_io.read_mat_ark(file): - ... - - Read ark to a 'dictionary': - d = { key:mat for key,mat in kaldi_io.read_mat_ark(file) } - """ - fd = open_or_fd(file_or_fd) - try: - key = read_key(fd) - while key: - mat = read_mat(fd) - yield key, mat - key = read_key(fd) - finally: - if fd is not file_or_fd : fd.close() - -def read_mat(file_or_fd): - """ [mat] = read_mat(file_or_fd) - Reads single kaldi matrix, supports ascii and binary. - file_or_fd : file, gzipped file, pipe or opened file descriptor. - """ - fd = open_or_fd(file_or_fd) - try: - binary = fd.read(2).decode() - if binary == '\0B' : - mat = _read_mat_binary(fd) - else: - assert(binary == ' [') - mat = _read_mat_ascii(fd) - finally: - if fd is not file_or_fd: fd.close() - return mat - -def _read_mat_binary(fd): - # Data type - header = fd.read(3).decode() - # 'CM', 'CM2', 'CM3' are possible values, - if header.startswith('CM'): return _read_compressed_mat(fd, header) - elif header == 'FM ': sample_size = 4 # floats - elif header == 'DM ': sample_size = 8 # doubles - else: raise UnknownMatrixHeader("The header contained '%s'" % header) - assert(sample_size > 0) - # Dimensions - s1, rows, s2, cols = np.frombuffer(fd.read(10), dtype='int8,int32,int8,int32', count=1)[0] - # Read whole matrix - buf = fd.read(rows * cols * sample_size) - if sample_size == 4 : vec = np.frombuffer(buf, dtype='float32') - elif sample_size == 8 : vec = np.frombuffer(buf, dtype='float64') - else : raise BadSampleSize - mat = np.reshape(vec,(rows,cols)) - return mat - -def _read_mat_ascii(fd): - rows = [] - while 1: - line = fd.readline().decode() - if (len(line) == 0) : raise BadInputFormat # eof, should not happen! - if len(line.strip()) == 0 : continue # skip empty line - arr = line.strip().split() - if arr[-1] != ']': - rows.append(np.array(arr,dtype='float32')) # not last line - else: - rows.append(np.array(arr[:-1],dtype='float32')) # last line - mat = np.vstack(rows) - return mat - - -def _read_compressed_mat(fd, format): - """ Read a compressed matrix, - see: https://github.com/kaldi-asr/kaldi/blob/master/src/matrix/compressed-matrix.h - methods: CompressedMatrix::Read(...), CompressedMatrix::CopyToMat(...), - """ - assert(format == 'CM ') # The formats CM2, CM3 are not supported... - - # Format of header 'struct', - global_header = np.dtype([('minvalue','float32'),('range','float32'),('num_rows','int32'),('num_cols','int32')]) # member '.format' is not written, - per_col_header = np.dtype([('percentile_0','uint16'),('percentile_25','uint16'),('percentile_75','uint16'),('percentile_100','uint16')]) - - # Read global header, - globmin, globrange, rows, cols = np.frombuffer(fd.read(16), dtype=global_header, count=1)[0] - - # The data is structed as [Colheader, ... , Colheader, Data, Data , .... 
] - # { cols }{ size } - col_headers = np.frombuffer(fd.read(cols*8), dtype=per_col_header, count=cols) - col_headers = np.array([np.array([x for x in y]) * globrange * 1.52590218966964e-05 + globmin for y in col_headers], dtype=np.float32) - data = np.reshape(np.frombuffer(fd.read(cols*rows), dtype='uint8', count=cols*rows), newshape=(cols,rows)) # stored as col-major, - - mat = np.zeros((cols,rows), dtype='float32') - p0 = col_headers[:, 0].reshape(-1, 1) - p25 = col_headers[:, 1].reshape(-1, 1) - p75 = col_headers[:, 2].reshape(-1, 1) - p100 = col_headers[:, 3].reshape(-1, 1) - mask_0_64 = (data <= 64) - mask_193_255 = (data > 192) - mask_65_192 = (~(mask_0_64 | mask_193_255)) - - mat += (p0 + (p25 - p0) / 64. * data) * mask_0_64.astype(np.float32) - mat += (p25 + (p75 - p25) / 128. * (data - 64)) * mask_65_192.astype(np.float32) - mat += (p75 + (p100 - p75) / 63. * (data - 192)) * mask_193_255.astype(np.float32) - - return mat.T # transpose! col-major -> row-major, - - -# Writing, -def write_mat(file_or_fd, m, key=''): - """ write_mat(f, m, key='') - Write a binary kaldi matrix to filename or stream. Supports 32bit and 64bit floats. - Arguments: - file_or_fd : filename of opened file descriptor for writing, - m : the matrix to be stored, - key (optional) : used for writing ark-file, the utterance-id gets written before the matrix. - - Example of writing single matrix: - kaldi_io.write_mat(filename, mat) - - Example of writing arkfile: - with open(ark_file,'w') as f: - for key,mat in dict.iteritems(): - kaldi_io.write_mat(f, mat, key=key) - """ - fd = open_or_fd(file_or_fd, mode='wb') - if sys.version_info[0] == 3: assert(fd.mode == 'wb') - try: - if key != '' : fd.write((key+' ').encode("latin1")) # ark-files have keys (utterance-id), - fd.write('\0B'.encode()) # we write binary! - # Data-type, - if m.dtype == 'float32': fd.write('FM '.encode()) - elif m.dtype == 'float64': fd.write('DM '.encode()) - else: raise UnsupportedDataType("'%s', please use 'float32' or 'float64'" % m.dtype) - # Dims, - fd.write('\04'.encode()) - fd.write(struct.pack(np.dtype('uint32').char, m.shape[0])) # rows - fd.write('\04'.encode()) - fd.write(struct.pack(np.dtype('uint32').char, m.shape[1])) # cols - # Data, - fd.write(m.tobytes()) - finally: - if fd is not file_or_fd : fd.close() - - -################################################# -# 'Posterior' kaldi type (posteriors, confusion network, nnet1 training targets, ...) -# Corresponds to: vector > > -# - outer vector: time axis -# - inner vector: records at the time -# - tuple: int = index, float = value -# - -def read_cnet_ark(file_or_fd): - """ Alias of function 'read_post_ark()', 'cnet' = confusion network """ - return read_post_ark(file_or_fd) - -def read_post_ark(file_or_fd): - """ generator(key,vec>) = read_post_ark(file) - Returns generator of (key,posterior) tuples, read from ark file. - file_or_fd : ark, gzipped ark, pipe or opened file descriptor. - - Iterate the ark: - for key,post in kaldi_io.read_post_ark(file): - ... - - Read ark to a 'dictionary': - d = { key:post for key,post in kaldi_io.read_post_ark(file) } - """ - fd = open_or_fd(file_or_fd) - try: - key = read_key(fd) - while key: - post = read_post(fd) - yield key, post - key = read_key(fd) - finally: - if fd is not file_or_fd: fd.close() - -def read_post(file_or_fd): - """ [post] = read_post(file_or_fd) - Reads single kaldi 'Posterior' in binary format. 
- - The 'Posterior' is C++ type 'vector > >', - the outer-vector is usually time axis, inner-vector are the records - at given time, and the tuple is composed of an 'index' (integer) - and a 'float-value'. The 'float-value' can represent a probability - or any other numeric value. - - Returns vector of vectors of tuples. - """ - fd = open_or_fd(file_or_fd) - ans=[] - binary = fd.read(2).decode(); assert(binary == '\0B'); # binary flag - assert(fd.read(1).decode() == '\4'); # int-size - outer_vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # number of frames (or bins) - - # Loop over 'outer-vector', - for i in range(outer_vec_size): - assert(fd.read(1).decode() == '\4'); # int-size - inner_vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # number of records for frame (or bin) - data = np.frombuffer(fd.read(inner_vec_size*10), dtype=[('size_idx','int8'),('idx','int32'),('size_post','int8'),('post','float32')], count=inner_vec_size) - assert(data[0]['size_idx'] == 4) - assert(data[0]['size_post'] == 4) - ans.append(data[['idx','post']].tolist()) - - if fd is not file_or_fd: fd.close() - return ans - - -################################################# -# Kaldi Confusion Network bin begin/end times, -# (kaldi stores CNs time info separately from the Posterior). -# - -def read_cntime_ark(file_or_fd): - """ generator(key,vec>) = read_cntime_ark(file_or_fd) - Returns generator of (key,cntime) tuples, read from ark file. - file_or_fd : file, gzipped file, pipe or opened file descriptor. - - Iterate the ark: - for key,time in kaldi_io.read_cntime_ark(file): - ... - - Read ark to a 'dictionary': - d = { key:time for key,time in kaldi_io.read_post_ark(file) } - """ - fd = open_or_fd(file_or_fd) - try: - key = read_key(fd) - while key: - cntime = read_cntime(fd) - yield key, cntime - key = read_key(fd) - finally: - if fd is not file_or_fd : fd.close() - -def read_cntime(file_or_fd): - """ [cntime] = read_cntime(file_or_fd) - Reads single kaldi 'Confusion Network time info', in binary format: - C++ type: vector >. - (begin/end times of bins at the confusion network). - - Binary layout is ' ...' - - file_or_fd : file, gzipped file, pipe or opened file descriptor. - - Returns vector of tuples. 
- """ - fd = open_or_fd(file_or_fd) - binary = fd.read(2).decode(); assert(binary == '\0B'); # assuming it's binary - - assert(fd.read(1).decode() == '\4'); # int-size - vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # number of frames (or bins) - - data = np.frombuffer(fd.read(vec_size*10), dtype=[('size_beg','int8'),('t_beg','float32'),('size_end','int8'),('t_end','float32')], count=vec_size) - assert(data[0]['size_beg'] == 4) - assert(data[0]['size_end'] == 4) - ans = data[['t_beg','t_end']].tolist() # Return vector of tuples (t_beg,t_end), - - if fd is not file_or_fd : fd.close() - return ans - - -################################################# -# Segments related, -# - -# Segments as 'Bool vectors' can be handy, -# - for 'superposing' the segmentations, -# - for frame-selection in Speaker-ID experiments, -def read_segments_as_bool_vec(segments_file): - """ [ bool_vec ] = read_segments_as_bool_vec(segments_file) - using kaldi 'segments' file for 1 wav, format : ' ' - - t-beg, t-end is in seconds, - - assumed 100 frames/second, - """ - segs = np.loadtxt(segments_file, dtype='object,object,f,f', ndmin=1) - # Sanity checks, - assert(len(segs) > 0) # empty segmentation is an error, - assert(len(np.unique([rec[1] for rec in segs ])) == 1) # segments with only 1 wav-file, - # Convert time to frame-indexes, - start = np.rint([100 * rec[2] for rec in segs]).astype(int) - end = np.rint([100 * rec[3] for rec in segs]).astype(int) - # Taken from 'read_lab_to_bool_vec', htk.py, - frms = np.repeat(np.r_[np.tile([False,True], len(end)), False], - np.r_[np.c_[start - np.r_[0, end[:-1]], end-start].flat, 0]) - assert np.sum(end-start) == np.sum(frms) - return frms - diff --git a/speech_tools/text2token.py b/speech_tools/text2token.py index 54ddd99e3..06571237b 100755 --- a/speech_tools/text2token.py +++ b/speech_tools/text2token.py @@ -39,7 +39,10 @@ def main(args): entry = line.rstrip().split() tokenized = tokenize(' '.join(entry[args.skip_ncols:]), space=args.space, non_lang_syms=nls) - print(' '.join(entry[:args.skip_ncols]) + ' ' + tokenized) + if args.skip_ncols > 0: + print(' '.join(entry[:args.skip_ncols]) + ' ' + tokenized) + else: + print(tokenized) if __name__ == '__main__': diff --git a/speech_train.py b/speech_train.py index 6f0116c96..b327c602e 100755 --- a/speech_train.py +++ b/speech_train.py @@ -35,7 +35,6 @@ def main(args, init_distributed=False): if torch.cuda.is_available() and not args.cpu: torch.cuda.set_device(args.device_id) torch.manual_seed(args.seed) - torch.backends.cudnn.deterministic = True if args.disable_cudnn: torch.backends.cudnn.enabled = False diff --git a/tests/test_speech_utils.py b/tests/test_speech_utils.py index 74f075896..4e66b5f96 100644 --- a/tests/test_speech_utils.py +++ b/tests/test_speech_utils.py @@ -77,7 +77,8 @@ def test_speech_tokenizer(self): # test :func:`~speech_tools.utils.tokenize` with # :func:`~TokenDictionary.encode_line` - tensor = self.dict.encode_line(tokens, append_eos=True) + tensor = self.dict.encode_line(tokens, add_if_not_exist=False, + append_eos=True) reconstructed_tokens = self.dict.string(tensor) expected_tokens = ' '.join( [token if self.dict.index(token) != self.dict.unk() else \ From 0e2926b90b4b5cd6f920eb02e7ee492484e0a8b0 Mon Sep 17 00:00:00 2001 From: freewym Date: Thu, 4 Apr 2019 14:59:05 -0400 Subject: [PATCH 017/119] code adaptation/changes according to the commits from Apr 1, 2019 to Apr 10, 2019 --- fairseq/data/speech_dataset.py | 26 +++++++++++------- speech_train.py | 48 
++++++++++++++++++---------------- 2 files changed, 41 insertions(+), 33 deletions(-) diff --git a/fairseq/data/speech_dataset.py b/fairseq/data/speech_dataset.py index 17deae6aa..e6d7d0f82 100644 --- a/fairseq/data/speech_dataset.py +++ b/fairseq/data/speech_dataset.py @@ -80,6 +80,20 @@ def merge(key, left_pad, move_eos_to_beginning=False): return batch +def generate_dummy_batch(num_tokens, collate_fn, feat_dim, max_sentences=16, src_len=300, dict=None, tgt_len=30): + """Return a dummy batch with a given number of tokens.""" + bsz = max(min(num_tokens // src_len, max_sentences), 1) + return collate_fn([ + { + 'id': i, + 'utt_id': 'dummy' + str(i), + 'source': torch.FloatTensor(src_len, feat_dim).uniform_(-10.0, 10.0), + 'target': dict.dummy_sentence(tgt_len) if dict is not None else None, + } + for i in range(bsz) + ]) + + class SpeechDataset(FairseqDataset): """ A pair of torch.utils.data.Datasets. @@ -202,16 +216,8 @@ def get_dummy_batch(self, num_tokens, max_positions, max_sentences=16, src_len=3 max_positions, (self.max_source_positions, self.max_target_positions), ) - bsz = max(min(num_tokens // src_len, max_sentences), 1) - return self.collater([ - { - 'id': i, - 'utt_id': 'dummy' + str(i), - 'source': torch.FloatTensor(src_len, self.src.feat_dim).uniform_(-10.0, 10.0), - 'target': self.dict.dummy_sentence(tgt_len) if self.dict is not None else None, - } - for i in range(bsz) - ]) + return generate_dummy_batch(num_tokens, self.collater, self.src.feat_dim, + max_sentences, src_len, self.dict, tgt_len) def num_tokens(self, index): """Return the number of frames in a sample. This value is used to diff --git a/speech_train.py b/speech_train.py index b327c602e..9ea3b9c93 100755 --- a/speech_train.py +++ b/speech_train.py @@ -42,13 +42,7 @@ def main(args, init_distributed=False): task = tasks.setup_task(args) # Load dataset splits - load_dataset_splits(task, ['train', 'valid']) - - # Initialize distributed training (after data loading) - if init_distributed: - import socket - args.distributed_rank = distributed_utils.distributed_init(args) - print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank)) + load_dataset_splits(args, task) # Build model and criterion model = task.build_model(args) @@ -67,8 +61,8 @@ def main(args, init_distributed=False): task.max_positions(), model.max_positions(), ) - dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions) - oom_batch = task.dataset('train').get_dummy_batch(1, max_positions) + dummy_batch = task.dataset(args.train_subset).get_dummy_batch(args.max_tokens, max_positions) + oom_batch = task.dataset(args.train_subset).get_dummy_batch(1, max_positions) # Build trainer trainer = Trainer(args, task, model, criterion, dummy_batch, oom_batch) @@ -92,6 +86,12 @@ def main(args, init_distributed=False): num_workers=args.num_workers, ) + # Initialize distributed training (after data loading) + if init_distributed: + import socket + args.distributed_rank = distributed_utils.distributed_init(args) + print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank)) + # Load the latest checkpoint if one is available if not load_checkpoint(args, trainer, epoch_itr): trainer.dummy_train_step([dummy_batch]) @@ -357,7 +357,11 @@ def save_checkpoint(args, trainer, epoch_itr, val_wer): def load_checkpoint(args, trainer, epoch_itr): """Load a checkpoint and replay dataloader to match.""" - os.makedirs(args.save_dir, exist_ok=True) + + # Only rank 0 should attempt to create the 
required dir + if args.distributed_rank == 0: + os.makedirs(args.save_dir, exist_ok=True) + if os.path.isabs(args.restore_file): checkpoint_path = args.restore_file else: @@ -382,19 +386,17 @@ def load_checkpoint(args, trainer, epoch_itr): return False -def load_dataset_splits(task, splits): - for split in splits: - if split == 'train': - task.load_dataset(split, combine=True) - else: - for k in itertools.count(): - split_k = split + (str(k) if k > 0 else '') - try: - task.load_dataset(split_k, combine=False) - except FileNotFoundError as e: - if k > 0: - break - raise e +def load_dataset_splits(args, task): + task.load_dataset(args.train_subset, combine=True) + for split in args.valid_subset.split(','): + for k in itertools.count(): + split_k = split + (str(k) if k > 0 else '') + try: + task.load_dataset(split_k, combine=False) + except FileNotFoundError as e: + if k > 0: + break + raise e def distributed_main(i, args): From 9e479ca1068925be53de75d72c7fb005457a0310 Mon Sep 17 00:00:00 2001 From: freewym Date: Sun, 14 Apr 2019 21:33:38 -0400 Subject: [PATCH 018/119] word lm related; add unigram/temporal label smoothing and update the wsj recipe accordingly; code adaptation/changes according to the commits from Apr 12, 2019 to May 21, 2019 --- examples/asr_wsj/cmd.sh | 8 +- examples/asr_wsj/local/score.sh | 45 ++ examples/asr_wsj/local/wsj_format_data.sh | 1 - examples/asr_wsj/path.sh | 1 + examples/asr_wsj/run.sh | 226 ++++++--- fairseq/criterions/cross_entropy_with_wer.py | 20 +- .../label_smoothed_cross_entropy_with_wer.py | 65 ++- fairseq/data/scp_dataset.py | 12 +- fairseq/data/speech_dataset.py | 26 - fairseq/data/token_dictionary.py | 17 +- fairseq/models/external_language_model.py | 474 ++++++++++++++++++ fairseq/models/speech_lstm.py | 217 ++++++-- .../lr_scheduler/reduce_lr_on_plateau_v2.py | 7 +- fairseq/tasks/language_modeling_for_asr.py | 9 +- fairseq/tasks/speech_recognition.py | 148 ++++-- fairseq/wer.py | 21 +- speech_recognize.py | 72 ++- speech_tools/Makefile | 2 +- speech_tools/compute_wer.py | 100 ++++ speech_tools/text2token.py | 3 +- speech_tools/text2vocabulary.py | 106 ++++ speech_tools/utils.py | 58 ++- speech_train.py | 219 ++------ 23 files changed, 1424 insertions(+), 433 deletions(-) create mode 100755 examples/asr_wsj/local/score.sh delete mode 120000 examples/asr_wsj/local/wsj_format_data.sh create mode 100644 fairseq/models/external_language_model.py create mode 100755 speech_tools/compute_wer.py create mode 100755 speech_tools/text2vocabulary.py diff --git a/examples/asr_wsj/cmd.sh b/examples/asr_wsj/cmd.sh index 008ac4efa..b14280b96 100644 --- a/examples/asr_wsj/cmd.sh +++ b/examples/asr_wsj/cmd.sh @@ -10,11 +10,11 @@ # conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information, # or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl. 
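One more note on the speech_train.py hunk above, whose comment says only rank 0 should attempt to create the required directory: a common reason for this guard is to keep several distributed workers from touching the same path on a shared filesystem at once. A small sketch of the pattern; the barrier is an optional extra step assumed here for illustration, not something the patch itself adds:

    import os
    import torch.distributed as dist

    def make_save_dir(save_dir, rank):
        # Only one process creates the directory.
        if rank == 0:
            os.makedirs(save_dir, exist_ok=True)
        # Optionally wait so every rank sees the directory before writing to it.
        if dist.is_available() and dist.is_initialized():
            dist.barrier()
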
-#export train_cmd="run.pl --mem 2G" -#export cuda_cmd="run.pl --mem 2G --gpu 1" +#export train_cmd="run.pl --mem 4G" +#export cuda_cmd="run.pl --mem 4G --gpu 1" #export decode_cmd="run.pl --mem 4G" # JHU setup -export train_cmd="queue.pl --mem 2G" -export cuda_cmd="queue.pl --mem 2G --gpu 1 --config conf/gpu.conf" +export train_cmd="queue.pl --mem 4G" +export cuda_cmd="queue.pl --mem 4G --gpu 1 --config conf/gpu.conf" export decode_cmd="queue.pl --mem 4G" diff --git a/examples/asr_wsj/local/score.sh b/examples/asr_wsj/local/score.sh new file mode 100755 index 000000000..80653d43d --- /dev/null +++ b/examples/asr_wsj/local/score.sh @@ -0,0 +1,45 @@ +#!/bin/bash + +# Copyright (c) 2019-present, Yiming Wang +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + + +# begin configuration section. +cmd=run.pl +#end configuration section. + +echo "$0 $@" # Print the command line for logging +[ -f ./path.sh ] && . ./path.sh +. ./utils/parse_options.sh + +if [ $# -ne 2 ]; then + echo "Usage: $0 [--cmd (run.pl|queue.pl...)] decode-dir>" + echo " Options:" + echo " --cmd (run.pl|queue.pl...) # specify how to run the sub-processes." + exit 1; +fi + +data=$1 +dir=$2 + + +ref_filtering_cmd="cat" +[ -x local/wer_output_filter ] && ref_filtering_cmd="local/wer_output_filter" +[ -x local/wer_ref_filter ] && ref_filtering_cmd="local/wer_ref_filter" +hyp_filtering_cmd="cat" +[ -x local/wer_output_filter ] && hyp_filtering_cmd="local/wer_output_filter" +[ -x local/wer_hyp_filter ] && hyp_filtering_cmd="local/wer_hyp_filter" + +mkdir -p $dir/scoring_kaldi/log +$ref_filtering_cmd $data/text > $dir/scoring_kaldi/test_filt.txt || exit 1; +$hyp_filtering_cmd $dir/decoded_results.txt > $dir/scoring_kaldi/hyp_filt.txt || exit 1; + +$cmd $dir/scoring_kaldi/log/score.log \ + cat $dir/scoring_kaldi/hyp_filt.txt \| \ + compute-wer --text --mode=present \ + ark:$dir/scoring_kaldi/test_filt.txt ark,p:- ">&" $dir/scoring_kaldi/wer || exit 1; + diff --git a/examples/asr_wsj/local/wsj_format_data.sh b/examples/asr_wsj/local/wsj_format_data.sh deleted file mode 120000 index 710e8f3a0..000000000 --- a/examples/asr_wsj/local/wsj_format_data.sh +++ /dev/null @@ -1 +0,0 @@ -../../../speech_tools/kaldi/egs/wsj/s5/local/wsj_format_data.sh \ No newline at end of file diff --git a/examples/asr_wsj/path.sh b/examples/asr_wsj/path.sh index 19a6eb6eb..80ed44a7f 100644 --- a/examples/asr_wsj/path.sh +++ b/examples/asr_wsj/path.sh @@ -11,5 +11,6 @@ export LC_ALL=C export PATH=~/anaconda3/bin:$PATH export PATH=$MAIN_ROOT:$MAIN_ROOT/speech_tools:$PATH +export PYTHONPATH=$MAIN_ROOT:$MAIN_ROOT/speech_tools:$PYTHONPATH export PYTHONUNBUFFERED=1 diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index f58bae379..177d28077 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -10,21 +10,24 @@ set -e -o pipefail stage=0 -free_gpu= +ngpus=1 # num GPUs for multiple GPUs training within a single node; should match those in $free_gpu +free_gpu= # comma-separated available GPU ids, eg., "0" or "0,1"; automatically assigned on CLSP grid -# e2e model related +# E2E model related affix= train_set=train_si284 valid_set=test_dev93 test_set=test_eval92 checkpoint=checkpoint_best.pt -validate_on_train=false -disable_cudnn=false +validate_on_train_subset=false # for monitoring E2E model training # LM related lm_affix= 
lm_checkpoint=checkpoint_best.pt -lm_shallow_fusion=false +lm_shallow_fusion=true # no LM fusion if false +use_wordlm=true # Only relevant when LM fusion is enabled. Use char LM if false +wordlm_affix= +wordlm_vocabsize=65000 # data related dumpdir=data/dump # directory to dump full features @@ -34,6 +37,9 @@ if [[ $(hostname -f) == *.clsp.jhu.edu ]]; then wsj0=/export/corpora5/LDC/LDC93S6B wsj1=/export/corpora5/LDC/LDC94S13B fi +train_subset_size=500 # for validation if validate_on_train_subset is set to true +kaldi_scoring=true + # feature configuration do_delta=false @@ -43,6 +49,7 @@ do_delta=false . ./utils/parse_options.sh lmdir=exp/lm_lstm${lm_affix:+_${lm_affix}} +wordlmdir=exp/wordlm_lstm${wordlm_affix:+_${wordlm_affix}} dir=exp/lstm${affix:+_$affix} if [ ${stage} -le 0 ]; then @@ -50,10 +57,21 @@ if [ ${stage} -le 0 ]; then ### But you can utilize Kaldi recipes in most cases echo "Stage 0: Data Preparation" local/wsj_data_prep.sh ${wsj0}/??-{?,??}.? ${wsj1}/??-{?,??}.? - local/wsj_format_data.sh + echo "Preparing train and test data" + srcdir=data/local/data + for x in train_si284 test_eval92 test_eval93 test_dev93 test_eval92_5k test_eval93_5k test_dev93_5k dev_dt_05 dev_dt_20; do + mkdir -p data/$x + cp $srcdir/${x}_wav.scp data/$x/wav.scp || exit 1; + cp $srcdir/$x.txt data/$x/text || exit 1; + cp $srcdir/$x.spk2utt data/$x/spk2utt || exit 1; + cp $srcdir/$x.utt2spk data/$x/utt2spk || exit 1; + utils/filter_scp.pl data/$x/spk2utt $srcdir/spk2gender > data/$x/spk2gender || exit 1; + done + echo "Succeeded in formatting data." fi train_feat_dir=${dumpdir}/${train_set}/delta${do_delta}; mkdir -p ${train_feat_dir} +train_subset_feat_dir=${dumpdir}/${train_set}_${train_subset_size}/delta${do_delta}; mkdir -p ${train_subset_feat_dir} valid_feat_dir=${dumpdir}/${valid_set}/delta${do_delta}; mkdir -p ${valid_feat_dir} test_feat_dir=${dumpdir}/${test_set}/delta${do_delta}; mkdir -p ${test_feat_dir} if [ ${stage} -le 1 ]; then @@ -82,11 +100,18 @@ if [ ${stage} -le 1 ]; then data/${valid_set}/feats.scp data/${train_set}/cmvn.ark exp/dump_feats/valid ${valid_feat_dir} dump.sh --cmd "$train_cmd" --nj 4 --do_delta $do_delta \ data/${test_set}/feats.scp data/${train_set}/cmvn.ark exp/dump_feats/test ${test_feat_dir} + + # randomly select a subset of train set for optional diagnosis + utils/subset_data_dir.sh data/${train_set} ${train_subset_size} data/${train_set}_${train_subset_size} + utils/filter_scp.pl data/${train_set}_${train_subset_size}/utt2spk ${train_feat_dir}/feats.scp \ + > ${train_subset_feat_dir}/feats.scp fi dict=data/lang/${train_set}_units.txt nlsyms=data/lang/non_lang_syms.txt lmdatadir=data/lm_text +wordlmdict=data/lang/wordlist_$wordlm_vocabsize.txt +wordlmdatadir=data/wordlm_text if [ ${stage} -le 2 ]; then echo "Stage 2: Dictionary Preparation and Text Tokenization" mkdir -p data/lang @@ -97,7 +122,7 @@ if [ ${stage} -le 2 ]; then cat $nlsyms echo "$0: making a dictionary and tokenizing text for train/valid/test set..." - for dataset in $train_set $valid_set $test_set; do + for dataset in $train_set ${train_set}_${train_subset_size} $valid_set $test_set; do text=data/$dataset/text token_text=data/$dataset/token_text text2token.py --skip-ncols 1 --space "" --non-lang-syms $nlsyms $text > $token_text @@ -108,114 +133,177 @@ if [ ${stage} -le 2 ]; then fi done - echo "$0: preparing text for LM..." 
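For reference, the tokenization applied by text2token.py in this stage maps each word to its character sequence and inter-word blanks to a dedicated space symbol, while keeping the non-language symbols listed in $nlsyms intact. A simplified stand-in for that behavior (symbol names are illustrative; the real script also supports --skip-ncols for leading utterance ids):

    def tokenize(sent, space='<space>', non_lang_syms=None):
        # Split a sentence into characters, keeping listed non-language symbols
        # (e.g. '<NOISE>') as single units and marking word boundaries with `space`.
        non_lang_syms = set(non_lang_syms or [])
        words, tokens = sent.strip().split(), []
        for i, w in enumerate(words):
            if w in non_lang_syms:
                tokens.append(w)
            else:
                tokens.extend(list(w))
            if i < len(words) - 1:
                tokens.append(space)
        return ' '.join(tokens)

    # tokenize('HELLO <NOISE> WORLD', non_lang_syms=['<NOISE>'])
    # -> 'H E L L O <space> <NOISE> <space> W O R L D'
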
- mkdir -p $lmdatadir - for dataset in $train_set $valid_set $test_set; do - token_text=data/$dataset/token_text - cut -f 2- -d" " $token_text > $lmdatadir/$dataset.tokens - done - zcat ${wsj1}/13-32.1/wsj1/doc/lng_modl/lm_train/np_data/{87,88,89}/*.z \ - | grep -v "<" | tr "[:lower:]" "[:upper:]" \ - | text2token.py --space "" > $lmdatadir/train_others.tokens - cat $lmdatadir/$train_set.tokens $lmdatadir/train_others.tokens > $lmdatadir/train.tokens + if ! $use_wordlm; then + echo "$0: preparing text for char LM..." + mkdir -p $lmdatadir + for dataset in $train_set $valid_set $test_set; do + token_text=data/$dataset/token_text + cut -f 2- -d" " $token_text > $lmdatadir/$dataset.tokens + done + zcat ${wsj1}/13-32.1/wsj1/doc/lng_modl/lm_train/np_data/{87,88,89}/*.z \ + | grep -v "<" | tr "[:lower:]" "[:upper:]" \ + | text2token.py --space "" > $lmdatadir/train_others.tokens + cat $lmdatadir/$train_set.tokens $lmdatadir/train_others.tokens > $lmdatadir/train.tokens + else + echo "$0: preparing text and making word dictionary for word LM..." + mkdir -p $wordlmdatadir + for dataset in $train_set $valid_set $test_set; do + text=data/$dataset/text + cut -f 2- -d" " $text > $wordlmdatadir/$dataset + done + zcat ${wsj1}/13-32.1/wsj1/doc/lng_modl/lm_train/np_data/{87,88,89}/*.z \ + | grep -v "<" | tr "[:lower:]" "[:upper:]" > $wordlmdatadir/train_others + cat $wordlmdatadir/$train_set $wordlmdatadir/train_others > $wordlmdatadir/train + text2vocabulary.py --vocabsize $wordlm_vocabsize --exclude " " \ + --valid-text $wordlmdatadir/$valid_set --test-text $wordlmdatadir/$test_set \ + $wordlmdatadir/train > $wordlmdict \ + 2> >(tee $(dirname $wordlmdict)/vocab${wordlm_vocabsize}_stats.log >&2) + fi fi lmdict=$dict if [ ${stage} -le 3 ]; then echo "Stage 3: Text Binarization for LM Training" - lmdict=$dict - mkdir -p $lmdatadir/logs - ${decode_cmd} $lmdatadir/logs/preprocess.log \ - python3 ../../preprocess.py --task language_modeling_for_asr \ - --workers 30 --srcdict $lmdict --only-source \ - --trainpref $lmdatadir/train.tokens \ - --validpref $lmdatadir/$valid_set.tokens \ - --testpref $lmdatadir/$test_set.tokens \ - --destdir $lmdatadir + if ! $use_wordlm; then + echo "$0: binarizing char text..." + mkdir -p $lmdatadir/logs + ${decode_cmd} $lmdatadir/logs/preprocess.log \ + python3 ../../preprocess.py --task language_modeling_for_asr \ + --workers 30 --srcdict $lmdict --only-source \ + --trainpref $lmdatadir/train.tokens \ + --validpref $lmdatadir/$valid_set.tokens \ + --testpref $lmdatadir/$test_set.tokens \ + --destdir $lmdatadir + else + echo "$0: binarizing word text..." + mkdir -p $wordlmdatadir/logs + ${decode_cmd} $wordlmdatadir/logs/preprocess.log \ + python3 ../../preprocess.py --task language_modeling_for_asr \ + --workers 30 --srcdict $wordlmdict --only-source \ + --trainpref $wordlmdatadir/train \ + --validpref $wordlmdatadir/$valid_set \ + --testpref $wordlmdatadir/$test_set \ + --destdir $wordlmdatadir + fi fi -if [ ${stage} -le 4 ]; then - echo "Stage 4: LM Training" +[ -z "$free_gpu" ] && [[ $(hostname -f) == *.clsp.jhu.edu ]] && free_gpu=$(free-gpu -n $ngpus) +[ -z "$free_gpu" ] && echo "$0: please specify --free-gpu" && exit 1; +[ $(echo $free_gpu | sed 's/,/ /g' | awk '{print NF}') -ne "$ngpus" ] && \ + echo "number of GPU ids in --free-gpu=$free_gpu does not match --ngpus=$ngpus" && exit 1; + +if [ ${stage} -le 4 ] && ! 
$use_wordlm; then + echo "Stage 4: char LM Training" valid_subset=valid mkdir -p $lmdir/logs log_file=$lmdir/logs/train.log [ -f $lmdir/checkpoint_last.pt ] && log_file="-a $log_file" - [ -z "$free_gpu" ] && [[ $(hostname -f) == *.clsp.jhu.edu ]] && free_gpu=$(free-gpu) - [ -z "$free_gpu" ] && echo "$0: please specify --free-gpu" && exit 1; CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../train.py $lmdatadir --seed 1 \ --task language_modeling_for_asr --dict $lmdict \ --log-interval 2000 --log-format simple \ --num-workers 0 --max-tokens 25600 --max-sentences 128 \ --valid-subset $valid_subset --max-sentences-valid 256 \ - --distributed-world-size 1 --distributed-rank 0 --distributed-port -1 \ + --distributed-world-size $ngpus --distributed-rank 0 --distributed-port 100 \ --max-epoch 25 --optimizer adam --lr 0.001 --weight-decay 5e-06 \ --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \ --save-dir $lmdir --restore-file checkpoint_last.pt --save-interval-updates 2000 \ --keep-interval-updates 5 --keep-last-epochs 5 --validate-interval 1 \ - --arch lstm_lm_wsj --dropout 0.1 --criterion cross_entropy \ - --sample-break-mode eos 2>&1 | tee $log_file + --arch lstm_lm_wsj --criterion cross_entropy --sample-break-mode eos 2>&1 | tee $log_file fi -if [ ${stage} -le 5 ]; then - echo "Stage 5: LM Evaluation" +if [ ${stage} -le 5 ] && ! $use_wordlm; then + echo "Stage 5: char LM Evaluation" for gen_subset in valid test; do log_file=$lmdir/logs/evaluation_$gen_subset.log - [ -z "$free_gpu" ] && [[ $(hostname -f) == *.clsp.jhu.edu ]] && free_gpu=$(free-gpu) - [ -z "$free_gpu" ] && echo "$0: please specify --free-gpu" && exit 1; - CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../eval_lm.py $lmdatadir \ + python3 ../../eval_lm.py $lmdatadir --cpu \ --task language_modeling_for_asr --dict $lmdict --gen-subset $gen_subset \ --max-tokens 192000 --max-sentences 256 --sample-break-mode eos \ --path $lmdir/$lm_checkpoint 2>&1 | tee $log_file done fi +if [ ${stage} -le 6 ] && $use_wordlm; then + echo "Stage 6: word LM Training" + valid_subset=valid + mkdir -p $wordlmdir/logs + log_file=$wordlmdir/logs/train.log + [ -f $wordlmdir/checkpoint_last.pt ] && log_file="-a $log_file" + CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../train.py $wordlmdatadir --seed 1 \ + --task language_modeling_for_asr --dict $wordlmdict \ + --log-interval 2000 --log-format simple \ + --num-workers 0 --max-tokens 6400 --max-sentences 256 \ + --valid-subset $valid_subset --max-sentences-valid 512 \ + --distributed-world-size $ngpus --distributed-rank 0 --distributed-port 100 \ + --max-epoch 20 --optimizer adam --lr 0.001 --weight-decay 1e-05 \ + --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \ + --save-dir $wordlmdir --restore-file checkpoint_last.pt --save-interval-updates 2000 \ + --keep-interval-updates 5 --keep-last-epochs 5 --validate-interval 1 \ + --arch lstm_wordlm_wsj --criterion cross_entropy \ + --sample-break-mode eos 2>&1 | tee $log_file +fi + +if [ ${stage} -le 7 ] && $use_wordlm; then + echo "Stage 7: word LM Evaluation" + for gen_subset in valid test; do + log_file=$wordlmdir/logs/evaluation_$gen_subset.log + python3 ../../eval_lm.py $wordlmdatadir --cpu \ + --task language_modeling_for_asr --dict $wordlmdict --gen-subset $gen_subset \ + --max-tokens 12800 --max-sentences 512 --sample-break-mode eos \ + --path $wordlmdir/$lm_checkpoint 2>&1 | tee $log_file + done +fi + train_feat=$train_feat_dir/feats.scp train_token_text=data/$train_set/token_text valid_feat=$valid_feat_dir/feats.scp 
valid_token_text=data/$valid_set/token_text -if [ ${stage} -le 6 ]; then - echo "Stage 6: Model Training" +if [ ${stage} -le 8 ]; then + echo "Stage 8: Model Training" + opts="" valid_subset=valid - if $validate_on_train; then - valid_subset="$valid_subset,train" + if $validate_on_train_subset; then + valid_subset="$valid_subset,train_subset" + opts="$opts --train-subset-feat-files $train_subset_feat_dir/feats.scp" + opts="$opts --train-subset-text-files data/${train_set}_${train_subset_size}/token_text" fi + [ -f local/wer_output_filter ] && opts="$opts --wer-output-filter local/wer_output_filter" mkdir -p $dir/logs log_file=$dir/logs/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" - opts="" - $disable_cudnn && opts="$opts --disable-cudnn" - [ -f local/wer_output_filter ] && \ - opts="$opts --wer-output-filter local/wer_output_filter" - [ -z "$free_gpu" ] && [[ $(hostname -f) == *.clsp.jhu.edu ]] && free_gpu=$(free-gpu) - [ -z "$free_gpu" ] && echo "$0: please specify --free-gpu" && exit 1; CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --seed 1 \ - --log-interval 500 --log-format simple --print-training-sample-interval 1000 \ + --log-interval 1000 --log-format simple --print-training-sample-interval 1000 \ --num-workers 0 --max-tokens 24000 --max-sentences 32 \ --valid-subset $valid_subset --max-sentences-valid 64 \ - --distributed-world-size 1 --distributed-rank 0 --distributed-port -1 \ - --max-epoch 20 --optimizer adam --lr 0.001 --weight-decay 0.0 \ - --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --min-lr 1e-4 --start-reduce-lr-epoch 11 \ - --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates 200 \ + --distributed-world-size $ngpus --distributed-rank 0 --distributed-port 100 \ + --max-epoch 25 --optimizer adam --lr 0.001 --weight-decay 0.0 \ + --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --min-lr 1e-5 --start-reduce-lr-epoch 11 \ + --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates 400 \ --keep-interval-updates 5 --keep-last-epochs 5 --validate-interval 1 \ - --arch speech_conv_lstm_wsj --criterion label_smoothed_cross_entropy_with_wer --label-smoothing 0.05 \ + --arch speech_conv_lstm_wsj --criterion label_smoothed_cross_entropy_with_wer \ + --label-smoothing 0.05 --smoothing-type temporal \ --train-feat-files $train_feat --train-text-files $train_token_text \ --valid-feat-files $valid_feat --valid-text-files $valid_token_text \ - --dict $dict --non-lang-syms $nlsyms --dropout 0.2 \ + --dict $dict --non-lang-syms $nlsyms \ --max-source-positions 9999 --max-target-positions 999 $opts 2>&1 | tee $log_file fi -if [ ${stage} -le 7 ]; then - echo "Stage 7: Decoding" +if [ ${stage} -le 9 ]; then + echo "Stage 9: Decoding" opts="" path=$dir/$checkpoint decode_affix= if $lm_shallow_fusion; then - path="$path:$lmdir/$lm_checkpoint" - opts="$opts --lprob-weights 1,0.29" - decode_affix=shallow_fusion + if ! 
$use_wordlm; then + path="$path:$lmdir/$lm_checkpoint" + opts="$opts --lm-weight 0.7 --coverage-weight 0.01" + decode_affix=shallow_fusion + else + path="$path:$wordlmdir/$lm_checkpoint" + opts="$opts --word-dict $wordlmdict --lm-weight 0.7 --oov-penalty 1e-4 --coverage-weight 0.01" + decode_affix=shallow_fusion_wordlm + fi fi - [ -f local/wer_output_filter ] && \ - opts="$opts --wer-output-filter local/wer_output_filter" + [ -f local/wer_output_filter ] && opts="$opts --wer-output-filter local/wer_output_filter" for dataset in $valid_set $test_set; do if [ "$dataset" == "$valid_set" ]; then feat=$valid_feat_dir/feats.scp @@ -223,15 +311,19 @@ if [ ${stage} -le 7 ]; then feat=$test_feat_dir/feats.scp fi text=data/$dataset/token_text - [ -z "$free_gpu" ] && [[ $(hostname -f) == *.clsp.jhu.edu ]] && free_gpu=$(free-gpu) - [ -z "$free_gpu" ] && echo "$0: please specify --free-gpu" && exit 1; - CUDA_VISIBLE_DEVICES=$free_gpu speech_recognize.py \ - --max-tokens 45000 --max-sentences 1 --num-shards 1 --shard-id 0 \ + CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py \ + --max-tokens 20000 --max-sentences 32 --num-shards 1 --shard-id 0 \ --test-feat-files $feat --test-text-files $text \ --dict $dict --non-lang-syms $nlsyms \ --max-source-positions 9999 --max-target-positions 999 \ - --path $path --beam 50 --max-len-a 0.5 --max-len-b 0 --lenpen 1.1 --no-early-stop \ + --path $path --beam 50 --max-len-a 0.2 --max-len-b 0 --lenpen 1.0 \ --results-path $dir/decode_$dataset${decode_affix:+_${decode_affix}} $opts \ --print-alignment 2>&1 | tee $dir/logs/decode_$dataset${decode_affix:+_${decode_affix}}.log + + if $kaldi_scoring; then + echo "verify WER by scoring with Kaldi..." + local/score.sh data/$dataset $dir/decode_$dataset${decode_affix:+_${decode_affix}} + cat $dir/decode_$dataset${decode_affix:+_${decode_affix}}/scoring_kaldi/wer + fi done fi diff --git a/fairseq/criterions/cross_entropy_with_wer.py b/fairseq/criterions/cross_entropy_with_wer.py index 52a21236b..07412bee3 100644 --- a/fairseq/criterions/cross_entropy_with_wer.py +++ b/fairseq/criterions/cross_entropy_with_wer.py @@ -5,7 +5,6 @@ # the root directory of this source tree. An additional grant of patent rights # can be found in the PATENTS file in the same directory. 
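The decoding stage above passes the shallow-fusion options --lm-weight, --oov-penalty and --coverage-weight to speech_recognize.py. Schematically, the per-token score during beam search interpolates the E2E model with the external LM and a coverage reward; the exact form used by the generator is not shown in this patch, so the sketch below is only an illustration of the idea:

    def fused_score(asr_lprob, lm_lprob, coverage, lm_weight=0.7, coverage_weight=0.01):
        # log p_e2e(y | x) + lm_weight * log p_lm(y) + coverage_weight * coverage(x, y),
        # where the coverage term rewards hypotheses whose attention covers the input frames.
        # With a word LM, --oov-penalty additionally scales the probability the LM
        # assigns to out-of-vocabulary words.
        return asr_lprob + lm_weight * lm_lprob + coverage_weight * coverage
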
-import math import numpy as np import torch import torch.nn.functional as F @@ -24,7 +23,7 @@ class CrossEntropyWithWERCriterion(CrossEntropyCriterion): def __init__(self, args, task): super().__init__(args, task) - dict = task.dict if hasattr(task, 'dict') else getattr(task, 'tgt_dict') + dict = task.target_dictionary self.scorer = wer.Scorer(dict, wer_output_filter=task.args.wer_output_filter) self.train_tgt_dataset = task.dataset(args.train_subset).tgt @@ -69,7 +68,7 @@ def forward(self, model, sample, reduce=True): # target, and the length of encoder_out if possible # and at least the length of target maxlen = max(encoder_out['encoder_out'][0].size(1), target.size(1)) - tokens = target.new_full([target.size(0), maxlen + 2], dict.pad()) + tokens = target.new_full([target.size(0), maxlen + 2], self.padding_idx) tokens[:, 0] = dict.eos() lprobs = [] attn = [] if model.decoder.need_attn else None @@ -87,7 +86,6 @@ def forward(self, model, sample, reduce=True): break log_probs, attn_scores = self._decode(tokens[:, :step + 1], model, encoder_out, incremental_states) - log_probs[:, dict.pad()] = -math.inf # never select pad tokens[:, step + 1] = log_probs.argmax(-1) if step > 0: # deal with finished predictions # make log_probs uniform if the previous output token is EOS @@ -114,16 +112,20 @@ def forward(self, model, sample, reduce=True): self.scorer.reset() for i in range(target.size(0)): utt_id = sample['utt_id'][i] - id = sample['id'].data[i] + id = sample['id'].data[i].item() #ref_tokens = dict.string(target.data[i]) - ref_tokens = self.valid_tgt_dataset.get_original_tokens(id) - pred_tokens = dict.string(pred.data[i]) - self.scorer.add_evaluation(utt_id, ref_tokens, pred_tokens) + # if it is a dummy batch (e.g., a "padding" batch in a sharded + # dataset), id might exceeds the dataset size; in that case we + # just skip it + if id < len( self.valid_tgt_dataset): + ref_tokens = self.valid_tgt_dataset.get_original_tokens(id) + pred_tokens = dict.string(pred.data[i]) + self.scorer.add_evaluation(utt_id, ref_tokens, pred_tokens) else: # print a randomly sampled result every print_interval updates assert pred.size() == target.size() with data_utils.numpy_seed(self.num_updates): i = np.random.randint(0, len(sample['id'])) - id = sample['id'].data[i] + id = sample['id'].data[i].item() length = utils.strip_pad(target.data[i], self.padding_idx).size(0) #ref_one = dict.tokens_to_sentence(dict.string(target.data[i])) ref_one = self.train_tgt_dataset.get_original_text(id, dict) diff --git a/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py b/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py index 57caac238..7b421bd96 100644 --- a/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py +++ b/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py @@ -5,7 +5,6 @@ # the root directory of this source tree. An additional grant of patent rights # can be found in the PATENTS file in the same directory. 
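The label_smoothed_cross_entropy_with_wer changes below add two smoothing types besides the usual uniform one: 'unigram', which smooths toward the training-set unigram distribution kept in self.unigram_tensor, and 'temporal', which smooths toward the neighbors of each target position following https://arxiv.org/pdf/1612.02695.pdf. A small per-sequence sketch of the temporal target distribution; the actual implementation below is batched and built with scatter_add_:

    import torch

    def temporal_smoothing_dist(target, vocab_size, pad_idx):
        # target: LongTensor of shape (tgt_len,) for one sequence.
        # Neighbors at distance 1 get weight 5, at distance 2 weight 2,
        # then every row is normalized into a distribution.
        tgt_len = target.size(0)
        q = torch.zeros(tgt_len, vocab_size)
        for t in range(tgt_len):
            for d, w in ((1, 5.0), (2, 2.0)):
                for u in (t - d, t + d):
                    if 0 <= u < tgt_len and target[u].item() != pad_idx:
                        q[t, target[u]] += w
        q[:, pad_idx] = 0.
        denom = q.sum(dim=-1, keepdim=True).clamp(min=1.0)
        return q / denom

    # The loss then mixes the usual NLL with the cross-entropy against q:
    #   loss = (1 - eps) * nll_loss + eps * (-(q * lprobs).sum(-1))
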
-import math import numpy as np import torch @@ -23,12 +22,17 @@ class LabelSmoothedCrossEntropyWithWERCriterion(LabelSmoothedCrossEntropyCriteri def __init__(self, args, task): super().__init__(args, task) - dict = task.dict if hasattr(task, 'dict') else getattr(task, 'tgt_dict') + dict = task.target_dictionary self.scorer = wer.Scorer(dict, wer_output_filter=task.args.wer_output_filter) - self.train_tgt_dataset = task.dataset(args.train_subset).tgt + self.train_tgt_dataset = None self.valid_tgt_dataset = None self.num_updates = -1 + if args.smoothing_type == 'unigram': + self.unigram_tensor = torch.cuda.FloatTensor(dict.count).unsqueeze(-1) \ + if torch.cuda.is_available() and not args.cpu \ + else torch.FloatTensor(dict.count).unsqueeze(-1) + self.unigram_tensor.div_(self.unigram_tensor.sum()) @staticmethod def add_args(parser): @@ -39,6 +43,9 @@ def add_args(parser): metavar='N', dest='print_interval', default=500, help='print a training sample (reference + ' 'prediction) every this number of updates') + parser.add_argument('--smoothing-type', type=str, default='uniform', + choices=['uniform', 'unigram', 'temporal'], + help='label smoothing type. Default: uniform') # fmt: on def forward(self, model, sample, reduce=True): @@ -69,7 +76,7 @@ def forward(self, model, sample, reduce=True): # target, and the length of encoder_out if possible # and at least the length of target maxlen = max(encoder_out['encoder_out'][0].size(1), target.size(1)) - tokens = target.new_full([target.size(0), maxlen + 2], dict.pad()) + tokens = target.new_full([target.size(0), maxlen + 2], self.padding_idx) tokens[:, 0] = dict.eos() lprobs = [] attn = [] if model.decoder.need_attn else None @@ -87,7 +94,6 @@ def forward(self, model, sample, reduce=True): break log_probs, attn_scores = self._decode(tokens[:, :step + 1], model, encoder_out, incremental_states) - #log_probs[:, dict.pad()] = -math.inf # never select pad tokens[:, step + 1] = log_probs.argmax(-1) if step > 0: # deal with finished predictions # make log_probs uniform if the previous output token is EOS @@ -114,16 +120,20 @@ def forward(self, model, sample, reduce=True): self.scorer.reset() for i in range(target.size(0)): utt_id = sample['utt_id'][i] - id = sample['id'].data[i] + id = sample['id'].data[i].item() #ref_tokens = dict.string(target.data[i]) - ref_tokens = self.valid_tgt_dataset.get_original_tokens(id) - pred_tokens = dict.string(pred.data[i]) - self.scorer.add_evaluation(utt_id, ref_tokens, pred_tokens) + # if it is a dummy batch (e.g., a "padding" batch in a sharded + # dataset), id might exceeds the dataset size; in that case we + # just skip it + if id < len( self.valid_tgt_dataset): + ref_tokens = self.valid_tgt_dataset.get_original_tokens(id) + pred_tokens = dict.string(pred.data[i]) + self.scorer.add_evaluation(utt_id, ref_tokens, pred_tokens) else: # print a randomly sampled result every print_interval updates assert pred.size() == target.size() with data_utils.numpy_seed(self.num_updates): i = np.random.randint(0, len(sample['id'])) - id = sample['id'].data[i] + id = sample['id'].data[i].item() length = utils.strip_pad(target.data[i], self.padding_idx).size(0) #ref_one = dict.tokens_to_sentence(dict.string(target.data[i])) ref_one = self.train_tgt_dataset.get_original_text(id, dict) @@ -132,15 +142,43 @@ def forward(self, model, sample, reduce=True): print('| sample REF: ' + ref_one) print('| sample PRD: ' + pred_one) # word error stats code ends + if self.args.smoothing_type == 'temporal': + # see 
https://arxiv.org/pdf/1612.02695.pdf + # prob_mask.dtype=int for deterministic behavior of Tensor.scatter_add_() + prob_mask = torch.zeros_like(lprobs, dtype=torch.int) # bsz x tgtlen x vocab_size + idx_tensor = target.new_full(target.size(), self.padding_idx).unsqueeze(-1) # bsz x tgtlen x 1 + # hard-code the remaining probabilty mass distributed symmetrically + # over neighbors at distance ±1 and ±2 with a 5 : 2 ratio + idx_tensor[:, 2:, 0] = target[:, :-2] # two neighbors to the left + prob_mask.scatter_add_(-1, idx_tensor, prob_mask.new([2]).expand_as(idx_tensor)) + idx_tensor.fill_(self.padding_idx)[:, 1:, 0] = target[:, :-1] + prob_mask.scatter_add_(-1, idx_tensor, prob_mask.new([5]).expand_as(idx_tensor)) + idx_tensor.fill_(self.padding_idx)[:, :-2, 0] = target[:, 2:] # two neighbors to the right + prob_mask.scatter_add_(-1, idx_tensor, prob_mask.new([2]).expand_as(idx_tensor)) + idx_tensor.fill_(self.padding_idx)[:, :-1, 0] = target[:, 1:] + prob_mask.scatter_add_(-1, idx_tensor, prob_mask.new([5]).expand_as(idx_tensor)) + prob_mask[:, :, self.padding_idx] = 0 # clear cumulative count on + prob_mask = prob_mask.float() # convert to float + sum_prob = prob_mask.sum(-1, keepdim=True) + sum_prob[sum_prob.squeeze(-1).eq(0.)] = 1. # deal with "divided by 0" problem + prob_mask = prob_mask.div_(sum_prob).view(-1, prob_mask.size(-1)) + lprobs = lprobs.view(-1, lprobs.size(-1)) target = target.view(-1, 1) non_pad_mask = target.ne(self.padding_idx) nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask] - smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask] + if self.args.smoothing_type == 'temporal': + smooth_loss = -lprobs.mul(prob_mask).sum(-1, keepdim=True)[non_pad_mask] + elif self.args.smoothing_type == 'unigram': + smooth_loss = -lprobs.matmul(self.unigram_tensor.to(lprobs))[non_pad_mask] + elif self.args.smoothing_type == 'uniform': + smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask] + else: + raise ValueError('Unsupported smoothing type: {}'.format(self.args.smoothing_type)) if reduce: nll_loss = nll_loss.sum() smooth_loss = smooth_loss.sum() - eps_i = self.eps / lprobs.size(-1) + eps_i = self.eps / lprobs.size(-1) if self.args.smoothing_type == 'uniform' else self.eps loss = (1. 
- self.eps) * nll_loss + eps_i * smooth_loss sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] logging_output = { @@ -188,6 +226,9 @@ def _decode(self, tokens, model, encoder_out, incremental_states): probs = probs[:, -1, :] return probs, attn + def set_train_tgt_dataset(self, dataset): + self.train_tgt_dataset = dataset + def set_valid_tgt_dataset(self, dataset): self.valid_tgt_dataset = dataset diff --git a/fairseq/data/scp_dataset.py b/fairseq/data/scp_dataset.py index 45ecde8b6..c8e4d3ee6 100644 --- a/fairseq/data/scp_dataset.py +++ b/fairseq/data/scp_dataset.py @@ -122,20 +122,20 @@ def __getitem__(self, i): self.start_pos_for_next_cache = pos_end \ if self.ordered_prefetch else 0 total_size = 0 - for idx in self.ordered_indices[pos_start : pos_end]: + for idx in self.ordered_indices[pos_start: pos_end]: total_size += self.sizes[idx] self.cache = np.empty((total_size, self.feat_dim), dtype=self.dtype) ptx = 0 self.cache_index.clear() - for idx in self.ordered_indices[pos_start : pos_end]: + for idx in self.ordered_indices[pos_start: pos_end]: self.cache_index[idx] = ptx length = self.sizes[idx] - dst = self.cache[ptx : ptx + length] + dst = self.cache[ptx: ptx + length] np.copyto(dst, kaldi_io.read_mat(self.extended_filenames[idx])) ptx += length ptx = self.cache_index[i] - a = self.cache[ptx : ptx + self.sizes[i]].copy() + a = self.cache[ptx: ptx + self.sizes[i]].copy() return torch.from_numpy(a).float() @@ -152,7 +152,7 @@ def read_data(self): dtype=self.dtype) for i in range(len(self.data_offsets)): ptx = self.data_offsets[i] - dst = self.buffer[ptx : ptx + self.sizes[i]] + dst = self.buffer[ptx: ptx + self.sizes[i]] np.copyto(dst, kaldi_io.read_mat(self.extended_filenames[i])) def filter_and_reorder(self, indices): @@ -162,7 +162,7 @@ def filter_and_reorder(self, indices): def __getitem__(self, i): self.check_index(i) ptx = self.data_offsets[i] - a = self.buffer[ptx : ptx + self.sizes[i]].copy() + a = self.buffer[ptx: ptx + self.sizes[i]].copy() return torch.from_numpy(a).float() diff --git a/fairseq/data/speech_dataset.py b/fairseq/data/speech_dataset.py index e6d7d0f82..699b96619 100644 --- a/fairseq/data/speech_dataset.py +++ b/fairseq/data/speech_dataset.py @@ -8,8 +8,6 @@ import numpy as np import torch -from fairseq import utils - from . import data_utils, FairseqDataset import speech_tools.utils as speech_utils @@ -80,20 +78,6 @@ def merge(key, left_pad, move_eos_to_beginning=False): return batch -def generate_dummy_batch(num_tokens, collate_fn, feat_dim, max_sentences=16, src_len=300, dict=None, tgt_len=30): - """Return a dummy batch with a given number of tokens.""" - bsz = max(min(num_tokens // src_len, max_sentences), 1) - return collate_fn([ - { - 'id': i, - 'utt_id': 'dummy' + str(i), - 'source': torch.FloatTensor(src_len, feat_dim).uniform_(-10.0, 10.0), - 'target': dict.dummy_sentence(tgt_len) if dict is not None else None, - } - for i in range(bsz) - ]) - - class SpeechDataset(FairseqDataset): """ A pair of torch.utils.data.Datasets. 
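The temporal label smoothing added to this criterion (following https://arxiv.org/pdf/1612.02695.pdf) smooths each target token towards its neighbours in the transcript, with weight 5 for tokens at distance 1 and weight 2 for distance 2, and never towards padding. A rough self-contained sketch of that weight construction (temporal_smoothing_weights is a made-up helper name; target is assumed to be a LongTensor of shape (bsz, tgt_len)):

import torch

def temporal_smoothing_weights(target, vocab_size, padding_idx):
    # Sketch of the temporal smoothing distribution, not the exact code used above.
    bsz, tgt_len = target.size()
    mask = target.new_zeros(bsz, tgt_len, vocab_size, dtype=torch.float)
    for offset, weight in ((-2, 2.), (-1, 5.), (1, 5.), (2, 2.)):
        shifted = target.new_full(target.size(), padding_idx)
        if offset < 0:   # neighbours to the left
            shifted[:, -offset:] = target[:, :offset]
        else:            # neighbours to the right
            shifted[:, :-offset] = target[:, offset:]
        mask.scatter_add_(-1, shifted.unsqueeze(-1),
                          mask.new_full((bsz, tgt_len, 1), weight))
    mask[:, :, padding_idx] = 0.                      # never smooth towards padding
    denom = mask.sum(-1, keepdim=True).clamp(min=1.)  # avoid division by zero
    return mask / denom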
@@ -209,16 +193,6 @@ def collater(self, samples): input_feeding=self.input_feeding, ) - def get_dummy_batch(self, num_tokens, max_positions, max_sentences=16, src_len=300, tgt_len=30): - """Return a dummy batch with a given number of tokens.""" - src_len, tgt_len = utils.resolve_max_positions( - (src_len, tgt_len), - max_positions, - (self.max_source_positions, self.max_target_positions), - ) - return generate_dummy_batch(num_tokens, self.collater, self.src.feat_dim, - max_sentences, src_len, self.dict, tgt_len) - def num_tokens(self, index): """Return the number of frames in a sample. This value is used to enforce ``--max-tokens`` during batching.""" diff --git a/fairseq/data/token_dictionary.py b/fairseq/data/token_dictionary.py index c3c1b6dea..ba409180b 100644 --- a/fairseq/data/token_dictionary.py +++ b/fairseq/data/token_dictionary.py @@ -13,15 +13,16 @@ class TokenDictionary(Dictionary): """A mapping from symbols to consecutive integers""" - def __init__(self, pad='', eos='', unk='', space=''): - self.unk_word, self.pad_word, self.eos_word, self.space_word = \ - unk, pad, eos, space + + def __init__(self, pad='', eos='', unk='', bos='', space=''): + self.unk_word, self.pad_word, self.eos_word, self.bos_word, self.space_word = \ + unk, pad, eos, bos, space self.symbols = [] self.count = [] self.indices = {} - self.pad_index = self.add_symbol(pad) - self.eos_index = self.add_symbol(eos) - self.unk_index = self.add_symbol(unk) + self.pad_index = self.add_symbol(pad, n=0) + self.eos_index = self.add_symbol(eos, n=0) + self.unk_index = self.add_symbol(unk, n=0) self.nspecial = len(self.symbols) self.non_lang_syms = None @@ -45,6 +46,10 @@ def token_string(i): i != self.pad()) return data_utils.process_bpe_symbol(sent, bpe_symbol) + def bos(self): + """Disallow beginning-of-sentence symbol""" + raise NotImplementedError + def space(self): """Helper to get index of space symbol""" return self.space_index diff --git a/fairseq/models/external_language_model.py b/fairseq/models/external_language_model.py new file mode 100644 index 000000000..4c38e36ed --- /dev/null +++ b/fairseq/models/external_language_model.py @@ -0,0 +1,474 @@ +# Copyright (c) 2019-present, Yiming Wang +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairseq import options, utils +from fairseq.data import TokenDictionary + +from . import FairseqIncrementalDecoder, FairseqLanguageModel + +from speech_tools.utils import tokenize, lexical_prefix_tree + + +def _clone_cached_state(cached_state): + if cached_state is None: + return None + + def clone_state(state): + if isinstance(state, list): + return [clone_state(state_i) for state_i in state] + return state.clone() if state is not None else None + + return tuple(map(clone_state, cached_state)) + + +class LookAheadWordLanguageModel(FairseqLanguageModel): + """A :class:`fairseq.models.FairseqLanguageModel` wrapper for + :class:`_LookAheadWordLanguageModelDecoder`. 
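_clone_cached_state exists so that the wrapped word LM can be advanced speculatively: the look-ahead decoder below snapshots the LM's incremental state, steps the LM for the whole batch, and then restores the snapshot for beam entries that have not just crossed a word boundary, via masked_copy_incremental_state, which the wrapped decoder is required to implement. A condensed sketch of that pattern (advance_wordlm is a made-up helper, assumed to live in the same module as _clone_cached_state):

from fairseq import utils

def advance_wordlm(lm_decoder, incremental_state, word_input, space_mask):
    # Sketch: snapshot the word-LM cached state, step the LM for every
    # hypothesis, then copy the snapshot back wherever space_mask is False
    # (previous subword was not <space>, so the word LM must not move).
    old_state = _clone_cached_state(utils.get_incremental_state(
        lm_decoder, incremental_state, 'cached_state'))
    lm_out = lm_decoder(word_input, incremental_state=incremental_state)
    lm_decoder.masked_copy_incremental_state(
        incremental_state, old_state, space_mask)
    return lm_out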
+ """ + def __init__(self, wordlm, subword_dict, oov_penalty=1e-4, open_vocab=True): + decoder = _LookAheadWordLanguageModelDecoder(wordlm, subword_dict, + oov_penalty, open_vocab) + super().__init__(decoder) + + +class _LookAheadWordLanguageModelDecoder(FairseqIncrementalDecoder): + """Look-ahead word language model decoder for end-to-end ASR. It is intended + to be used for beam search decoding. See https://arxiv.org/abs/1808.02608 + for details. + """ + def __init__(self, wordlm, subword_dict, oov_penalty=1e-4, open_vocab=True): + super().__init__(wordlm.decoder.dictionary) + + assert isinstance(wordlm, FairseqLanguageModel) + self.lm_decoder = wordlm.decoder + assert hasattr(self.lm_decoder, 'masked_copy_incremental_state') and \ + callable(self.lm_decoder.masked_copy_incremental_state), \ + 'The wrapped decoder should implement masked_copy_incremental_state()' + self.oov_penalty = oov_penalty + self.open_vocab = open_vocab + self.zero = 1e-10 # a sufficiently small value to avoid the log(0) issue + + word_dict = self.lm_decoder.dictionary + assert isinstance(word_dict, TokenDictionary) + self.word_eos_idx = word_dict.eos() + self.word_unk_idx = word_dict.unk() + + assert isinstance(subword_dict, TokenDictionary) + self.subword_space_idx = subword_dict.space() + self.subword_eos_idx = subword_dict.eos() + self.subword_vocab_size = len(subword_dict) + + tokenizer = lambda x: tokenize( + x, non_lang_syms=subword_dict.non_lang_syms).split(' ') + self.lexroot = lexical_prefix_tree(word_dict, subword_dict, tokenizer) + + @torch.no_grad() + def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): + assert incremental_state is not None, \ + 'this model is for incremental decoding only' + prev_output_tokens = prev_output_tokens[:, -1:] + bsz = prev_output_tokens.size(0) + + batch_space_mask = prev_output_tokens.squeeze(-1).eq(self.subword_space_idx) + + cached_state = utils.get_incremental_state( + self.lm_decoder, incremental_state, 'cached_state') + + if cached_state is None: # it is the first time step + assert (prev_output_tokens == self.subword_eos_idx).all(), \ + 'expecting the input to the first time step to be ' + w = prev_output_tokens.new_full([bsz, 1], self.word_eos_idx) + lm_out = self.lm_decoder(w, incremental_state=incremental_state) + cumsum_probs = torch.cumsum(self.lm_decoder.get_normalized_probs( + lm_out, log_probs=False, sample=None), dim=-1) # B x 1 x V + nodes = [self.lexroot] * bsz + else: + cumsum_probs = utils.get_incremental_state( + self, incremental_state, 'cumsum_probs') + nodes = utils.get_incremental_state(self, incremental_state, 'nodes') + assert len(nodes) == bsz + w = prev_output_tokens.new([ + node.word_idx if node is not None and node.word_idx >= 0 else \ + self.word_unk_idx for node in nodes + ]).unsqueeze(-1) # B x 1 + old_cached_state = _clone_cached_state(cached_state) + lm_out = self.lm_decoder(w, incremental_state=incremental_state) + self.lm_decoder.masked_copy_incremental_state(incremental_state, + old_cached_state, batch_space_mask) + # recompute cumsum_probs from inter-word transition probabilities + # only for those whose prev_output_token is + cumsum_probs[batch_space_mask] = torch.cumsum( + self.lm_decoder.get_normalized_probs(lm_out, log_probs=False, + sample=None), + dim=-1, + )[batch_space_mask] + tokens_list = prev_output_tokens.squeeze(-1).tolist() + for i in range(bsz): + if tokens_list[i] == self.subword_space_idx: + # inter-word transition: go back to root + nodes[i] = self.lexroot + elif nodes[i] is not None and 
tokens_list[i] in nodes[i].children: + # intra-word transition: go to child + nodes[i] = nodes[i].children[tokens_list[i]] + else: # no path in the tree + nodes[i] = None + + utils.set_incremental_state( + self, incremental_state, 'cumsum_probs', cumsum_probs) + utils.set_incremental_state(self, incremental_state, 'nodes', nodes) + + # initialize out_probs (B x 1 x V) + if self.open_vocab: + # set out_probs to oov_penalty * P(|h) (case 3 in Eqn. 15) + out_probs = self.oov_penalty * ( + cumsum_probs[:, :, self.word_unk_idx] - \ + cumsum_probs[:, :, self.word_unk_idx - 1] + ).unsqueeze(-1).repeat(1, 1, self.subword_vocab_size) + # set the probability of emitting or to 0 if + # prev_output_tokens is or + batch_space_eos_mask = batch_space_mask | \ + prev_output_tokens.squeeze(-1).eq(self.subword_eos_idx) + out_probs[batch_space_eos_mask, :, self.subword_space_idx] = self.zero + out_probs[batch_space_eos_mask, :, self.subword_eos_idx] = self.zero + # set transition probability to 1 for those whose node is out of the + # tree, i.e. node is None (case 4 in Eqn. 15) + batch_node_none_mask = [] + for node in nodes: + batch_node_none_mask.append(node is None) + batch_node_none_mask = batch_space_mask.new(batch_node_none_mask) + out_probs[batch_node_none_mask] = 1. + else: + # set out_probs to 0 + out_probs = cumsum_probs.new_full([bsz, 1, self.subword_vocab_size], + self.zero) + + # compute parent probabilities for those whose node is not None + sum_probs = cumsum_probs.new_full([bsz, 1], 1.) # default for root node + left_ranges, right_ranges, batch_node_not_root_mask = [], [], [] + for node in nodes: + if node is not None and node.word_set is not None: + left_ranges.append([node.word_set[0]]) + right_ranges.append([node.word_set[1]]) + batch_node_not_root_mask.append(True) + else: + batch_node_not_root_mask.append(False) + if len(left_ranges) > 0: + # b x 1 x 1 + left_ranges = prev_output_tokens.new(left_ranges).unsqueeze(-1) + right_ranges = prev_output_tokens.new(right_ranges).unsqueeze(-1) + batch_node_not_root_mask = batch_space_mask.new(batch_node_not_root_mask) + sum_probs[batch_node_not_root_mask] = ( + cumsum_probs[batch_node_not_root_mask].gather(-1, right_ranges) - \ + cumsum_probs[batch_node_not_root_mask].gather(-1, left_ranges) + ).squeeze(-1) + + # compute transition probabilities to child nodes (case 2 in Eqn. 15) + for i in range(bsz): + node = nodes[i] + if node is not None and len(node.children) > 0: + subword_idx, left_ranges, right_ranges = [], [], [] + for sidx, child in node.children.items(): + subword_idx.append(sidx) + left_ranges.append(child.word_set[0]) + right_ranges.append(child.word_set[1]) + subword_idx = prev_output_tokens.new(subword_idx) + left_ranges = prev_output_tokens.new(left_ranges) + right_ranges = prev_output_tokens.new(right_ranges) + out_probs[i, :, subword_idx] = \ + self.zero if sum_probs[i].item() < self.zero else \ + (cumsum_probs[i, :, right_ranges] - \ + cumsum_probs[i, :, left_ranges]) / sum_probs[i] + + # apply word-level probabilies for and (case 1 in Eqn. 
15) + word_idx, batch_node_word_end_mask = [], [] + for node in nodes: + if node is not None and node.word_idx >= 0: + word_idx.append([node.word_idx]) + batch_node_word_end_mask.append(True) + else: + batch_node_word_end_mask.append(False) + if len(word_idx) > 0: + word_idx = prev_output_tokens.new(word_idx).unsqueeze(-1) # b x 1 x 1 + batch_node_word_end_mask = batch_space_mask.new(batch_node_word_end_mask) + word_probs = torch.where( + sum_probs[batch_node_word_end_mask] < self.zero, + cumsum_probs.new([self.zero]), + (cumsum_probs[batch_node_word_end_mask].gather(-1, word_idx) - \ + cumsum_probs[batch_node_word_end_mask].gather(-1, word_idx - 1) + ).squeeze(-1).div_(sum_probs[batch_node_word_end_mask]), + ) # b x 1 + out_probs[batch_node_word_end_mask, :, self.subword_space_idx] = word_probs + out_probs[batch_node_word_end_mask, :, self.subword_eos_idx] = word_probs + + # take log of probs and clip it from below to avoid log(0) + out_logprobs = torch.max(out_probs, out_probs.new([self.zero])).log_() + + # add log-probs of emitting word to that of emitting subword + cached_state = _clone_cached_state(utils.get_incremental_state( + self.lm_decoder, incremental_state, 'cached_state')) # for restore later + w = prev_output_tokens.new([ + node.word_idx if node is not None and node.word_idx >= 0 else \ + self.word_unk_idx for node in nodes + ]).unsqueeze(-1) # B x 1 + word_eos_logprobs = self.lm_decoder.get_normalized_probs( + self.lm_decoder(w, incremental_state=incremental_state), + log_probs=True, + sample=None, + )[:, :, self.word_eos_idx] + utils.set_incremental_state(self.lm_decoder, incremental_state, + 'cached_state', cached_state) # restore decoder's state + out_logprobs[:, :, self.subword_eos_idx] += word_eos_logprobs + + # note that here we return log-probs rather than logits, and the second + # element is None, which is usually a tensor of attention weights in + # attention-based models + return out_logprobs, None + + def reorder_incremental_state(self, incremental_state, new_order): + super().reorder_incremental_state(incremental_state, new_order) + + cumsum_probs = utils.get_incremental_state( + self, incremental_state, 'cumsum_probs') + if cumsum_probs is not None: + new_cumsum_probs = cumsum_probs.index_select(0, new_order) + utils.set_incremental_state(self, incremental_state, 'cumsum_probs', + new_cumsum_probs) + + nodes = utils.get_incremental_state(self, incremental_state, 'nodes') + if nodes is not None: + new_order_list = new_order.tolist() + new_nodes = [nodes[i] for i in new_order_list] + utils.set_incremental_state(self, incremental_state, 'nodes', + new_nodes) + + def get_normalized_probs(self, net_output, log_probs, sample=None): + """Get normalized probabilities (or log probs) from a net's output.""" + # in-place op as not being used for backprop + return net_output[0] if log_probs else net_output[0].exp_() + + def max_positions(self): + return int(1e5) # an arbitrary large number + + +class MultiLevelLanguageModel(FairseqLanguageModel): + """A :class:`fairseq.models.FairseqLanguageModel` wrapper for + :class:`_MultiLevelLanguageModel`. + """ + def __init__(self, wordlm, subwordlm, subwordlm_weight=0.8, oov_penalty=1.0, + open_vocab=True): + decoder = _MultiLevelLanguageModel(wordlm, subwordlm, subwordlm_weight, + oov_penalty, open_vocab) + super().__init__(decoder) + + +class _MultiLevelLanguageModel(FairseqIncrementalDecoder): + """Multi-level (subword/word) language model decoder for end-to-end ASR. + It is intended to be used for beam search decoding. 
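The transition probabilities computed above rely on the lexical prefix tree storing, for every node, a contiguous range of word ids (node.word_set), so that the mass of all words reachable from a node is a difference of two entries of the cumulative word distribution. A scalar sketch of the per-child transition probability (child_transition_prob is a made-up name; cumsum_probs is one hypothesis's inclusive cumulative sum over the word vocabulary, and the (left, right) convention is assumed to follow the gather calls above):

def child_transition_prob(cumsum_probs, node, child, zero=1e-10):
    # P(child | node, history) = mass(words under child) / mass(words under node);
    # non-root nodes only, where word_set = (left, right) bounds into cumsum_probs.
    left, right = node.word_set
    node_mass = cumsum_probs[right] - cumsum_probs[left]
    if node_mass < zero:
        return zero  # numerically empty subtree: fall back to the floor value
    child_left, child_right = child.word_set
    return (cumsum_probs[child_right] - cumsum_probs[child_left]) / node_mass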
+ See https://ieeexplore.ieee.org/document/8268948 for details. + """ + def __init__(self, wordlm, subwordlm, subwordlm_weight=0.8, oov_penalty=1.0, + open_vocab=True): + super().__init__(wordlm.decoder.dictionary) + + assert isinstance(wordlm, FairseqLanguageModel) + self.wordlm_decoder = wordlm.decoder + assert hasattr(self.wordlm_decoder, 'masked_copy_incremental_state') and \ + callable(self.wordlm_decoder.masked_copy_incremental_state), \ + 'The wrapped decoder should implement masked_copy_incremental_state()' + assert isinstance(subwordlm, FairseqLanguageModel) + self.subwordlm_decoder = subwordlm.decoder + self.subwordlm_weight = subwordlm_weight + self.log_oov_penalty = math.log(oov_penalty) + self.open_vocab = open_vocab + self.logzero = -10.0 + + word_dict = self.wordlm_decoder.dictionary + assert isinstance(word_dict, TokenDictionary) + self.word_eos_idx = word_dict.eos() + self.word_unk_idx = word_dict.unk() + + subword_dict = self.subwordlm_decoder.dictionary + assert isinstance(subword_dict, TokenDictionary) + self.subword_space_idx = subword_dict.space() + self.subword_eos_idx = subword_dict.eos() + self.subword_vocab_size = len(subword_dict) + + tokenizer = lambda x: tokenize( + x, non_lang_syms=subword_dict.non_lang_syms).split(' ') + self.lexroot = lexical_prefix_tree(word_dict, subword_dict, tokenizer) + + @torch.no_grad() + def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): + assert incremental_state is not None, \ + 'this model is for incremental decoding only' + prev_output_tokens = prev_output_tokens[:, -1:] + bsz = prev_output_tokens.size(0) + + batch_space_mask = prev_output_tokens.squeeze(-1).eq(self.subword_space_idx) + batch_not_space_mask = ~batch_space_mask + + wordlm_cached_state = utils.get_incremental_state( + self.wordlm_decoder, incremental_state, 'cached_state') + subwordlm_cached_state = utils.get_incremental_state( + self.subwordlm_decoder, incremental_state, 'cached_state') + + if wordlm_cached_state is None: # it is the first time step + assert subwordlm_cached_state is None + assert (prev_output_tokens == self.subword_eos_idx).all(), \ + 'expecting the input to the first time step to be ' + w = prev_output_tokens.new_full([bsz, 1], self.word_eos_idx) + wordlm_logprobs = self.wordlm_decoder.get_normalized_probs( + self.wordlm_decoder(w, incremental_state=incremental_state), + log_probs=True, + sample=None, + ) # B x 1 x V + sw = prev_output_tokens.new_full([bsz, 1], self.subword_eos_idx) + out_logprobs = self.subwordlm_decoder.get_normalized_probs( + self.subwordlm_decoder(sw, incremental_state=incremental_state), + log_probs=True, + sample=None, + ) * self.subwordlm_weight # B x 1 x V + subword_cumlogprobs = out_logprobs.new_zeros(sw.size()) + nodes = [self.lexroot] * bsz + else: + wordlm_logprobs = utils.get_incremental_state(self, + incremental_state, 'wordlm_logprobs') + out_logprobs = utils.get_incremental_state(self, incremental_state, + 'out_logprobs') + subword_cumlogprobs = utils.get_incremental_state(self, + incremental_state, 'subword_cumlogprobs') + nodes = utils.get_incremental_state(self, incremental_state, 'nodes') + assert len(nodes) == bsz + w = prev_output_tokens.new([ + node.word_idx if node is not None and node.word_idx >= 0 else \ + self.word_unk_idx for node in nodes + ]).unsqueeze(-1) # B x 1 + old_wordlm_cached_state = _clone_cached_state(wordlm_cached_state) + + # recompute wordlm_logprobs from inter-word transition probabilities + # only for those whose prev_output_token is + 
wordlm_logprobs[batch_space_mask] = self.wordlm_decoder.get_normalized_probs( + self.wordlm_decoder(w, incremental_state=incremental_state), + log_probs=True, + sample=None, + )[batch_space_mask] + self.wordlm_decoder.masked_copy_incremental_state(incremental_state, + old_wordlm_cached_state, batch_space_mask) + + tokens_list = prev_output_tokens.squeeze(-1).tolist() + token_idx, batch_is_child_mask = [], [] + for i in range(bsz): + if tokens_list[i] == self.subword_space_idx: + # inter-word transition: go back to root + nodes[i] = self.lexroot + batch_is_child_mask.append(False) + elif nodes[i] is not None and tokens_list[i] in nodes[i].children: + # intra-word transition: go to child + nodes[i] = nodes[i].children[tokens_list[i]] + token_idx.append([tokens_list[i]]) + batch_is_child_mask.append(True) + else: # no path in the tree + nodes[i] = None + if self.open_vocab: + token_idx.append([tokens_list[i]]) + batch_is_child_mask.append(False) + token_idx = prev_output_tokens.new(token_idx).unsqueeze(-1) # b x 1 x 1 + if self.open_vocab: + subword_cumlogprobs[batch_space_mask] = 0. + assert batch_not_space_mask.sum().item() == len(token_idx) + subword_cumlogprobs[batch_not_space_mask] += \ + out_logprobs[batch_not_space_mask].gather(-1, token_idx).squeeze(-1) + else: + subword_cumlogprobs[~batch_is_child_mask] = 0. + assert batch_is_child_mask.sum().item() == len(token_idx) + subword_cumlogprobs[batch_is_child_mask] += \ + out_logprobs[batch_is_child_mask].gather(-1, token_idx).squeeze(-1) + + out_logprobs = self.subwordlm_decoder.get_normalized_probs( + self.subwordlm_decoder(prev_output_tokens, incremental_state=incremental_state), + log_probs=True, + sample=None, + ) * self.subwordlm_weight + + if not self.open_vocab: + batch_oov_mask = batch_not_space_mask & ~batch_is_child_mask + out_logprobs[batch_oov_mask] = self.logzero + + utils.set_incremental_state(self, incremental_state, 'wordlm_logprobs', + wordlm_logprobs) + utils.set_incremental_state(self, incremental_state, 'subword_cumlogprobs', + subword_cumlogprobs) + utils.set_incremental_state(self, incremental_state, 'nodes', nodes) + + # apply word-level probabilies for emitting or + w = prev_output_tokens.new([ + node.word_idx if node is not None and node.word_idx >= 0 else \ + self.word_unk_idx for node in nodes + ]).unsqueeze(-1) # B x 1 + word_logprobs = wordlm_logprobs.gather(-1, w.unsqueeze(-1)).squeeze(-1) # B x 1 + batch_word_end_mask = w.ne(self.word_unk_idx) + word_logprobs += torch.where(batch_word_end_mask, + -subword_cumlogprobs, word_logprobs.new([self.log_oov_penalty])) + out_logprobs[:, :, self.subword_space_idx] = word_logprobs + out_logprobs[:, :, self.subword_eos_idx] = word_logprobs + + # set the probability of emitting or to 0 if + # prev_output_tokens is or + batch_space_eos_mask = batch_space_mask | \ + prev_output_tokens.squeeze(-1).eq(self.subword_eos_idx) + out_logprobs[batch_space_eos_mask, :, self.subword_space_idx] = self.logzero + out_logprobs[batch_space_eos_mask, :, self.subword_eos_idx] = self.logzero + + # add log-probs of emitting word to that of emitting subword + cached_state = _clone_cached_state(utils.get_incremental_state( + self.wordlm_decoder, incremental_state, 'cached_state')) # for restore later + word_eos_logprobs = self.wordlm_decoder.get_normalized_probs( + self.wordlm_decoder(w, incremental_state=incremental_state), + log_probs=True, + sample=None, + )[:, :, self.word_eos_idx] + out_logprobs[:, :, self.subword_eos_idx] += word_eos_logprobs + + utils.set_incremental_state(self, 
incremental_state, 'out_logprobs', + out_logprobs) + utils.set_incremental_state(self.wordlm_decoder, incremental_state, + 'cached_state', cached_state) # restore decoder's state + + # note that here we return log-probs rather than logits, and the second + # element is None, which is usually a tensor of attention weights in + # attention-based models + return out_logprobs, None + + def reorder_incremental_state(self, incremental_state, new_order): + super().reorder_incremental_state(incremental_state, new_order) + + for state_name in ['wordlm_logprobs', 'out_logprobs', 'subword_cumlogprobs']: + state = utils.get_incremental_state(self, incremental_state, state_name) + if state is not None: + new_state = state.index_select(0, new_order) + utils.set_incremental_state(self, incremental_state, state_name, + new_state) + + nodes = utils.get_incremental_state(self, incremental_state, 'nodes') + if nodes is not None: + new_order_list = new_order.tolist() + new_nodes = [nodes[i] for i in new_order_list] + utils.set_incremental_state(self, incremental_state, 'nodes', + new_nodes) + + def get_normalized_probs(self, net_output, log_probs, sample=None): + """Get normalized probabilities (or log probs) from a net's output.""" + # in-place op as not being used for backprop + return net_output[0] if log_probs else net_output[0].exp_() + + def max_positions(self): + return int(1e5) # an arbitrary large number diff --git a/fairseq/models/speech_lstm.py b/fairseq/models/speech_lstm.py index 8f4734ad7..ac2aba446 100644 --- a/fairseq/models/speech_lstm.py +++ b/fairseq/models/speech_lstm.py @@ -9,12 +9,16 @@ import torch.nn as nn import torch.nn.functional as F -from fairseq import options, utils -from fairseq.modules import AdaptiveSoftmax, speech_attention -from . import ( - FairseqEncoder, FairseqIncrementalDecoder, FairseqModel, - FairseqLanguageModel, register_model, register_model_architecture, +from fairseq import options, utils, checkpoint_utils +from fairseq.models import ( + FairseqEncoder, + FairseqIncrementalDecoder, + FairseqLanguageModel, + FairseqEncoderDecoderModel, + register_model, + register_model_architecture, ) +from fairseq.modules import AdaptiveSoftmax, speech_attention from .lstm import AttentionLayer, Embedding, LSTM, LSTMCell, Linear @@ -22,7 +26,7 @@ @register_model('speech_lstm') -class SpeechLSTMModel(FairseqModel): +class SpeechLSTMModel(FairseqEncoderDecoderModel): def __init__(self, encoder, decoder, pretrained_lm=None): super().__init__(encoder, decoder) self.pretrained_lm = pretrained_lm @@ -47,6 +51,10 @@ def add_args(parser): help='number of rnn encoder layers') parser.add_argument('--encoder-rnn-bidirectional', action='store_true', help='make all rnn layers of encoder bidirectional') + parser.add_argument('--encoder-rnn-residual', action='store_true', + help='create residual connections for rnn encoder ' + 'layers (starting from the 2nd layer), i.e., the actual ' + 'output of such layer is the sum of its input and output') parser.add_argument('--decoder-embed-dim', type=int, metavar='N', help='decoder embedding dimension') parser.add_argument('--decoder-embed-path', type=str, metavar='STR', @@ -59,6 +67,10 @@ def add_args(parser): help='number of decoder layers') parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N', help='decoder output embedding dimension') + parser.add_argument('--decoder-rnn-residual', action='store_true', + help='create residual connections for rnn decoder ' + 'layers (starting from the 2nd layer), i.e., the actual ' + 'output of such 
layer is the sum of its input and output') parser.add_argument('--attention-type', type=str, metavar='STR', choices=['bahdanau','luong'], help='attention type') @@ -168,6 +180,7 @@ def eval_str_nested_list_or_tuple(x, type=int): dropout_in=args.encoder_rnn_dropout_in, dropout_out=args.encoder_rnn_dropout_out, bidirectional=args.encoder_rnn_bidirectional, + residual=args.encoder_rnn_residual, ) decoder = SpeechLSTMDecoder( dictionary=task.target_dictionary, @@ -181,6 +194,7 @@ def eval_str_nested_list_or_tuple(x, type=int): attn_type=args.attention_type, attn_dim=args.attention_dim, need_attn=args.need_attention, + residual=args.decoder_rnn_residual, pretrained_embed=pretrained_decoder_embed, share_input_output_embed=args.share_decoder_input_output_embed, adaptive_softmax_cutoff=( @@ -191,7 +205,7 @@ def eval_str_nested_list_or_tuple(x, type=int): pretrained_lm = None if args.pretrained_lm_checkpoint: print('| loading pretrained LM from {}'.format(args.pretrained_lm_checkpoint)) - pretrained_lm = utils.load_ensemble_for_inference( + pretrained_lm = checkpoint_utils.load_model_ensemble( args.pretrained_lm_checkpoint, task)[0][0] pretrained_lm.make_generation_fast_() # freeze pretrained model @@ -214,8 +228,9 @@ def max_decoder_positions(self): @register_model('lstm_lm') class LSTMLanguageModel(FairseqLanguageModel): - def __init__(self, decoder): + def __init__(self, decoder, args): super().__init__(decoder) + self.is_wordlm = args.is_wordlm @staticmethod def add_args(parser): @@ -240,6 +255,11 @@ def add_args(parser): 'Must be used with adaptive_loss criterion') parser.add_argument('--share-embed', action='store_true', help='share input and output embeddings') + parser.add_argument('--is-wordlm', action='store_true', + help='whether it is word LM or subword LM. 
Only ' + 'relevant for ASR decoding with LM, and it determines ' + 'how the underlying decoder instance gets the dictionary' + 'from the task instance when calling cls.build_model()') # Granular dropout settings (if not specified these default to --dropout) parser.add_argument('--decoder-dropout-in', type=float, metavar='D', @@ -262,12 +282,16 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): utils.print_embed_overlap(embed_dict, dictionary) return utils.load_embedding(embed_dict, dictionary, embed_tokens) + dictionary = task.word_dictionary \ + if args.is_wordlm and hasattr(task, 'word_dictionary') \ + else task.target_dictionary + # separate decoder input embeddings pretrained_decoder_embed = None if args.decoder_embed_path: pretrained_decoder_embed = load_pretrained_embedding_from_file( args.decoder_embed_path, - task.target_dictionary, + dictionary, args.decoder_embed_dim ) # one last double check of parameter combinations @@ -282,7 +306,7 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): pretrained_decoder_embed.weight.requires_grad = False decoder = SpeechLSTMDecoder( - dictionary=task.target_dictionary, + dictionary=dictionary, embed_dim=args.decoder_embed_dim, hidden_size=args.decoder_hidden_size, out_embed_dim=args.decoder_out_embed_dim, @@ -296,7 +320,7 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): if args.criterion == 'adaptive_loss' else None ), ) - return LSTMLanguageModel(decoder) + return LSTMLanguageModel(decoder, args) class ConvBNReLU(nn.Module): @@ -345,7 +369,7 @@ def forward(self, src, src_lengths): x = x.contiguous().view(x.size(0), x.size(1), x.size(2) * x.size(3)) x_lengths = self.output_lengths(src_lengths) - padding_mask = 1 - speech_utils.sequence_mask(x_lengths, x.size(1)) + padding_mask = ~speech_utils.sequence_mask(x_lengths, x.size(1)) if padding_mask.any(): x = x.masked_fill(padding_mask.unsqueeze(-1), 0.0) @@ -357,7 +381,7 @@ class SpeechLSTMEncoder(FairseqEncoder): def __init__( self, conv_layers_before=None, input_size=80, hidden_size=512, num_layers=1, dropout_in=0.1, dropout_out=0.1, bidirectional=False, - left_pad=False, pretrained_embed=None, padding_value=0., + residual=False, left_pad=False, pretrained_embed=None, padding_value=0., ): super().__init__(None) # no src dictionary self.conv_layers_before = conv_layers_before @@ -366,14 +390,16 @@ def __init__( self.dropout_out = dropout_out self.bidirectional = bidirectional self.hidden_size = hidden_size + self.residual = residual - self.lstm = LSTM( - input_size=input_size, - hidden_size=hidden_size, - num_layers=num_layers, - #dropout=self.dropout_out if num_layers > 1 else 0., - bidirectional=bidirectional, - ) + self.lstm = nn.ModuleList([ + LSTM( + input_size=input_size if layer == 0 else 2 * hidden_size if self.bidirectional else hidden_size, + hidden_size=hidden_size, + bidirectional=bidirectional, + ) + for layer in range(num_layers) + ]) self.left_pad = left_pad self.padding_value = padding_value @@ -399,7 +425,7 @@ def forward(self, src_tokens, src_lengths): src_lengths) else: x, padding_mask = src_tokens, \ - 1 - speech_utils.sequence_mask(src_lengths, src_tokens.size(1)) + ~speech_utils.sequence_mask(src_lengths, src_tokens.size(1)) bsz, seqlen = x.size(0), x.size(1) @@ -408,21 +434,28 @@ def forward(self, src_tokens, src_lengths): # B x T x C -> T x B x C x = x.transpose(0, 1) - # pack embedded source tokens into a PackedSequence - packed_x = nn.utils.rnn.pack_padded_sequence(x, 
src_lengths.data.tolist()) - - # apply LSTM - if self.bidirectional: - state_size = 2 * self.num_layers, bsz, self.hidden_size - else: - state_size = self.num_layers, bsz, self.hidden_size - h0 = x.new_zeros(*state_size) - c0 = x.new_zeros(*state_size) - packed_outs, (final_hiddens, final_cells) = self.lstm(packed_x, (h0, c0)) - - # unpack outputs and apply dropout - x, _ = nn.utils.rnn.pad_packed_sequence(packed_outs, padding_value=self.padding_value) - x = F.dropout(x, p=self.dropout_out, training=self.training) + state_size = (2 if self.bidirectional else 1) * self.num_layers, bsz, self.hidden_size + h0, c0 = x.new_zeros(*state_size), x.new_zeros(*state_size) + final_hiddens, final_cells = x.new_empty(*state_size), x.new_empty(*state_size) + + for i in range(len(self.lstm)): + if self.residual and i > 0: # residual connection starts from the 2nd layer + prev_x = x + # pack embedded source tokens into a PackedSequence + packed_x = nn.utils.rnn.pack_padded_sequence(x, src_lengths.data.tolist()) + + # apply LSTM + h0_i = h0[i * 2 : (i + 1) * 2] + c0_i = c0[i * 2 : (i + 1) * 2] + final_hiddens_i = final_hiddens[i * 2 : (i + 1) * 2] + final_cells_i = final_cells[i * 2 : (i + 1) * 2] + packed_outs, (final_hiddens_i, final_cells_i) = self.lstm[i](packed_x, (h0_i, c0_i)) + + # unpack outputs and apply dropout + x, _ = nn.utils.rnn.pad_packed_sequence(packed_outs, padding_value=self.padding_value) + if i < len(self.lstm) - 1: # not applying dropout for the last layer + x = F.dropout(x, p=self.dropout_out, training=self.training) + x = x + prev_x if self.residual and i > 0 else x assert list(x.size()) == [seqlen, bsz, self.output_units] if self.bidirectional: @@ -461,7 +494,7 @@ class SpeechLSTMDecoder(FairseqIncrementalDecoder): def __init__( self, dictionary, embed_dim=512, hidden_size=512, out_embed_dim=512, num_layers=1, dropout_in=0.1, dropout_out=0.1, encoder_output_units=0, - attn_type=None, attn_dim=0, need_attn=False, pretrained_embed=None, + attn_type=None, attn_dim=0, need_attn=False, residual=False, pretrained_embed=None, share_input_output_embed=False, adaptive_softmax_cutoff=None, ): super().__init__(dictionary) @@ -474,6 +507,7 @@ def __init__( need_attn = False encoder_output_units = 0 self.need_attn = need_attn + self.residual = residual self.adaptive_softmax = None num_embeddings = len(dictionary) @@ -511,11 +545,38 @@ def __init__( elif not self.share_input_output_embed: self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out) - def forward(self, prev_output_tokens, encoder_out_dict=None, incremental_state=None): + def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for input feeding/teacher forcing + encoder_out (Tensor, optional): output from the encoder, used for + encoder-side attention + incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - attention weights of shape `(batch, tgt_len, src_len)` + """ + x, extra = self.extract_features(prev_output_tokens, encoder_out, incremental_state) + x = self.output_layer(x) + return x, extra + + def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused): + """ + Similar to *forward* but only return features. 
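The encoder change above replaces the single multi-layer LSTM with one nn.LSTM per layer so that a residual connection can be added from the second layer on, packing and unpacking the sequence at every layer and skipping dropout after the last one. A condensed sketch of that per-layer loop (made-up helper name; unidirectional case, all layers sharing the same hidden size, src_lengths a descending list of true lengths, hidden/cell bookkeeping omitted):

import torch.nn as nn
import torch.nn.functional as F

def run_residual_lstm_stack(layers, x, src_lengths, dropout_out, training, residual=True):
    # layers: nn.ModuleList of nn.LSTM modules; x: (seq_len, batch, input_size).
    for i, lstm in enumerate(layers):
        prev_x = x
        packed = nn.utils.rnn.pack_padded_sequence(x, src_lengths)
        packed_out, _ = lstm(packed)
        x, _ = nn.utils.rnn.pad_packed_sequence(packed_out)
        if i < len(layers) - 1:   # no dropout after the last layer
            x = F.dropout(x, p=dropout_out, training=training)
        if residual and i > 0:    # residual connection from the 2nd layer on
            x = x + prev_x
    return x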
+ + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - attention weights of shape `(batch, tgt_len, src_len)` + """ if self.attention is not None: - assert encoder_out_dict is not None - encoder_out = encoder_out_dict['encoder_out'] - encoder_padding_mask = encoder_out_dict['encoder_padding_mask'] + assert encoder_out is not None + encoder_padding_mask = encoder_out['encoder_padding_mask'] + encoder_out = encoder_out['encoder_out'] # get outputs from encoder encoder_outs = encoder_out[0] srclen = encoder_outs.size(0) @@ -555,6 +616,8 @@ def forward(self, prev_output_tokens, encoder_out_dict=None, incremental_state=N for i, rnn in enumerate(self.layers): # recurrent cell hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i])) + if self.residual and i > 0: # residual connection starts from the 2nd layer + prev_layer_hidden = input[:, :hidden.size(1)] # compute and apply attention using the 1st layer's hidden state if self.attention is not None: @@ -568,6 +631,12 @@ def forward(self, prev_output_tokens, encoder_out_dict=None, incremental_state=N else: input = hidden input = F.dropout(input, p=self.dropout_out, training=self.training) + if self.residual and i > 0: + if self.attention is not None: + hidden_sum = input[:, :hidden.size(1)] + prev_layer_hidden + input = torch.cat((hidden_sum, input[:, hidden.size(1):]), dim=1) + else: + input = input + prev_layer_hidden # save state for next time step prev_hiddens[i] = hidden @@ -598,16 +667,21 @@ def forward(self, prev_output_tokens, encoder_out_dict=None, incremental_state=N else: attn_scores = None - # project back to size of vocabulary + return x, attn_scores + + def output_layer(self, features, **kwargs): + """ project features to the vocabulary size.""" if self.adaptive_softmax is None: + # project back to size of vocabulary if hasattr(self, 'additional_fc'): - x = self.additional_fc(x) - x = F.dropout(x, p=self.dropout_out, training=self.training) + x = self.additional_fc(features) + return F.dropout(x, p=self.dropout_out, training=self.training) if self.share_input_output_embed: - x = F.linear(x, self.embed_tokens.weight) + return F.linear(features, self.embed_tokens.weight) else: - x = self.fc_out(x) - return x, attn_scores + return self.fc_out(features) + else: + return features def reorder_incremental_state(self, incremental_state, new_order): super().reorder_incremental_state(incremental_state, new_order) @@ -623,6 +697,30 @@ def reorder_state(state): new_state = tuple(map(reorder_state, cached_state)) utils.set_incremental_state(self, incremental_state, 'cached_state', new_state) + def masked_copy_incremental_state(self, incremental_state, another_cached_state, mask): + cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state') + if cached_state is None: + assert another_cached_state is None + return + + def mask_copy_state(state, another_state): + if isinstance(state, list): + assert isinstance(another_state, list) and len(state) == len(another_state) + return [mask_copy_state(state_i, another_state_i) \ + for state_i, another_state_i in zip(state, another_state)] + if state is not None: + assert state.size(0) == mask.size(0) and another_state is not None and \ + state.size() == another_state.size() + for _ in range(1, len(state.size())): + mask_unsqueezed = mask.unsqueeze(-1) + return torch.where(mask_unsqueezed, state, another_state) + else: + assert another_state is None + return None + + new_state = tuple(map(mask_copy_state, cached_state, another_cached_state)) + 
utils.set_incremental_state(self, incremental_state, 'cached_state', new_state) + def max_positions(self): """Maximum output length supported by the decoder.""" return int(1e5) # an arbitrary large number @@ -655,17 +753,19 @@ def Convolution2d(in_channels, out_channels, kernel_size, stride): @register_model_architecture('lstm_lm', 'lstm_lm') def base_lm_architecture(args): - args.dropout = getattr(args, 'dropout', 0.2) + args.dropout = getattr(args, 'dropout', 0.1) args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 48) args.decoder_embed_path = getattr(args, 'decoder_embed_path', None) args.decoder_freeze_embed = getattr(args, 'decoder_freeze_embed', False) - args.decoder_hidden_size = getattr(args, 'decoder_hiden_size', 650) + args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 650) args.decoder_layers = getattr(args, 'decoder_layers', 2) args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 650) + args.decoder_rnn_residual = getattr(args, 'decoder_rnn_residual', False) args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None) args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', args.dropout) args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout) args.share_embed = getattr(args, 'share_embed', False) + args.is_wordlm = getattr(args, 'is_wordlm', False) @register_model_architecture('lstm_lm', 'lstm_lm_wsj') @@ -673,9 +773,21 @@ def lstm_lm_wsj(args): base_lm_architecture(args) +@register_model_architecture('lstm_lm', 'lstm_wordlm_wsj') +def lstm_wordlm_wsj(args): + args.dropout = getattr(args, 'dropout', 0.3) + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024) + args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 1024) + args.decoder_layers = getattr(args, 'decoder_layers', 1) + args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 1024) + args.share_embed = getattr(args, 'share_embed', True) + args.is_wordlm = True + base_lm_architecture(args) + + @register_model_architecture('speech_lstm', 'speech_lstm') def base_architecture(args): - args.dropout = getattr(args, 'dropout', 0.1) + args.dropout = getattr(args, 'dropout', 0.3) args.encoder_conv_channels = getattr(args, 'encoder_conv_channels', '[64, 64, 128, 128]') args.encoder_conv_kernel_sizes = getattr(args, 'encoder_conv_kernel_sizes', @@ -685,12 +797,14 @@ def base_architecture(args): args.encoder_rnn_hidden_size = getattr(args, 'encoder_rnn_hidden_size', 320) args.encoder_rnn_layers = getattr(args, 'encoder_rnn_layers', 3) args.encoder_rnn_bidirectional = getattr(args, 'encoder_rnn_bidirectional', True) + args.encoder_rnn_residual = getattr(args, 'encoder_rnn_residual', False) args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 48) args.decoder_embed_path = getattr(args, 'decoder_embed_path', None) args.decoder_freeze_embed = getattr(args, 'decoder_freeze_embed', False) args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 320) args.decoder_layers = getattr(args, 'decoder_layers', 3) args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 960) + args.decoder_rnn_residual = getattr(args, 'decoder_rnn_residual', True) args.attention_type = getattr(args, 'attention_type', 'bahdanau') args.attention_dim = getattr(args, 'attention_dim', 320) args.need_attention = getattr(args, 'need_attention', False) @@ -704,7 +818,4 @@ def base_architecture(args): @register_model_architecture('speech_lstm', 'speech_conv_lstm_wsj') def conv_lstm_wsj(args): - args.encoder_rnn_hidden_size = 
getattr(args, 'encoder_rnn_hidden_size', 512) - args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 512) - args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 1536) base_architecture(args) diff --git a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau_v2.py b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau_v2.py index 45259f97f..14c5a11bf 100644 --- a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau_v2.py +++ b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau_v2.py @@ -1,4 +1,4 @@ -# Copyright (c) 2017-present, Facebook, Inc. +# Copyright (c) 2019-present, Yiming Wang # All rights reserved. # # This source code is licensed under the license found in the LICENSE file in @@ -13,17 +13,18 @@ @register_lr_scheduler('reduce_lr_on_plateau_v2') class ReduceLROnPlateauV2(ReduceLROnPlateau): - """Decay the LR by a factor every time the validation loss plateausi, after start_epoch_to_reduce.""" + """Decay the LR by a factor every time the validation loss plateaus, after start_epoch_to_reduce.""" def __init__(self, args, optimizer): super().__init__(args, optimizer) self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( self.optimizer.optimizer, patience=0, factor=args.lr_shrink, - min_lr=args.min_lr) + threshold=args.lr_threshold, min_lr=args.min_lr) @staticmethod def add_args(parser): """Add arguments to the parser for this LR scheduler.""" + ReduceLROnPlateau.add_args(parser) # fmt: off parser.add_argument('--start-reduce-lr-epoch', default=0, type=int, metavar='N', help='start to reduce lr from specified epoch') diff --git a/fairseq/tasks/language_modeling_for_asr.py b/fairseq/tasks/language_modeling_for_asr.py index 0745e21ef..5270257ff 100644 --- a/fairseq/tasks/language_modeling_for_asr.py +++ b/fairseq/tasks/language_modeling_for_asr.py @@ -11,11 +11,10 @@ from fairseq import tokenizer from fairseq.data import TokenDictionary +from fairseq.tasks import register_task from .language_modeling import LanguageModelingTask -from . import register_task - @register_task('language_modeling_for_asr') class LanguageModelingForASRTask(LanguageModelingTask): @@ -56,7 +55,7 @@ def add_args(parser): help='path to the dictionary') # fmt: on - def __init__(self, args, dictionary, output_dictionary, targets=None): + def __init__(self, args, dictionary, output_dictionary=None, targets=None): super().__init__(args, dictionary, output_dictionary, targets=targets) torch.backends.cudnn.deterministic = True @@ -99,7 +98,9 @@ def setup_task(cls, args, **kwargs): dictionary = None output_dictionary = None if args.data: - dict_path = os.path.join(args.data, 'dict.txt') if args.dict is None \ + paths = args.data.split(':') + assert len(paths) > 0 + dict_path = os.path.join(paths[0], 'dict.txt') if args.dict is None \ else args.dict dictionary = TokenDictionary.load(dict_path) print('| dictionary: {} types'.format(len(dictionary))) diff --git a/fairseq/tasks/speech_recognition.py b/fairseq/tasks/speech_recognition.py index f21e3257f..c82986dc1 100644 --- a/fairseq/tasks/speech_recognition.py +++ b/fairseq/tasks/speech_recognition.py @@ -11,7 +11,7 @@ import os import re -from fairseq import options, utils +from fairseq import options from fairseq.data import ( ConcatDataset, data_utils, @@ -30,7 +30,7 @@ class SpeechRecognitionTask(FairseqTask): Transcribe from speech (source) to token text (target). Args: - dict (Dictionary): dictionary for the output tokens + dict (~fairseq.data.TokenDictionary): dictionary for the output tokens .. 
note:: @@ -50,10 +50,12 @@ def add_args(parser): """Add task-specific arguments to the parser.""" # fmt: off parser.add_argument('--train-feat-files', nargs='+', - help='path(s) to scp feature file(s) for training') + help='path(s) to scp feature file(s) for training, ' + 'will be iterated upon during epochs in round-robin manner') parser.add_argument('--train-text-files', nargs='+', help='path(s) to text file(s) for training, where ' - 'each should matches with one in --train-feat-files') + 'each should matches with one in --train-feat-files, ' + 'will be iterated upon during epochs in round-robin manner') parser.add_argument('--valid-feat-files', nargs='+', help='path(s) to scp feature file(s) for validation') parser.add_argument('--valid-text-files', nargs='+', @@ -61,14 +63,21 @@ def add_args(parser): 'each should matches with one in --valid-feat-files') parser.add_argument('--test-feat-files', nargs='+', help='path(s) to scp feature file(s) for test') - parser.add_argument('--test-text-files', nargs='+', - help='path(s) to text file(s) for test, where ' - 'each should matches with one in --test-feat-files') + parser.add_argument('--test-text-files', nargs='*', default=None, + help='path(s) to text file(s) for test. if not None, ' + 'each one should matches with one in --test-feat-files') + parser.add_argument('--train-subset-feat-files', nargs='+', + help='path(s) to scp feature file(s) for validation') + parser.add_argument('--train-subset-text-files', nargs='+', + help='path(s) to text file(s) for validation, where ' + 'each should matches with one in --train-subset-feat-files') parser.add_argument('--dict', default=None, type=str, help='path to the dictionary') parser.add_argument('--non-lang-syms', default=None, type=str, - help='list of non-linguistic symbols, e.g., ' - 'etc. To be filtered out when calculating WER/CER') + help='path to a file listing non-linguistic symbols, e.g., ' + 'etc. One entry per line. To be filtered out when calculating WER/CER.') + parser.add_argument('--word-dict', default=None, type=str, + help='path to the word dictionary. 
Only relevant for decoding') parser.add_argument('--wer-output-filter', default=None, type=str, help='path to wer_output_filter file for WER evaluation') parser.add_argument('--left-pad-source', default='False', type=str, metavar='BOOL', @@ -100,24 +109,10 @@ def build_dictionary(cls, filenames, workers=1, threshold=-1, nwords=-1, padding """ raise NotImplementedError - @staticmethod - def load_pretrained_model(path, dict_path, non_lang_syms=None, - arg_overrides=None): - model = utils.load_checkpoint_to_cpu(path) - args = model['args'] - state_dict = model['model'] - args = utils.override_model_args(args, arg_overrides) - dict = cls.load_dictionary(dict_path, non_lang_syms=non_lang_syms) - - task = SpeechRecognitionTask(args, dict) - model = task.build_model(args) - model.upgrade_state_dict(state_dict) - model.load_state_dict(state_dict, strict=True) - return model - - def __init__(self, args, dict): + def __init__(self, args, dict, word_dict=None): super().__init__(args) self.dict = dict + self.word_dict = word_dict self.feat_in_channels = args.feat_in_channels torch.backends.cudnn.deterministic = True @@ -132,14 +127,20 @@ def setup_task(cls, args, **kwargs): args.left_pad_target = options.eval_bool(args.left_pad_target) # load dictionaries - dict_path = os.path.join(os.path.dirname(args.text_files[0]), - 'dict.txt') if args.dict is None else args.dict + dict_path = os.path.join(os.path.dirname(args.text_files[0]), 'dict.txt') \ + if args.dict is None and args.text_files is not None else args.dict + assert dict_path is not None, 'Please specify --dict' dict = cls.load_dictionary(dict_path, non_lang_syms=args.non_lang_syms) print('| dictionary: {} types'.format(len(dict))) + if args.word_dict is not None: + word_dict = cls.load_dictionary(args.word_dict) + print('| word dictionary: {} types'.format(len(word_dict))) + return cls(args, dict, word_dict) - return cls(args, dict) + else: + return cls(args, dict) - def load_dataset(self, split, combine=False, **kwargs): + def load_dataset(self, split, epoch=0, combine=False, **kwargs): """Load a given dataset split. 
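As the new help text for --train-feat-files/--train-text-files says, multiple training shards are iterated over in a round-robin manner: each epoch, load_dataset (below) picks exactly one feat/text pair via epoch % len(...). A tiny illustration with hypothetical shard names:

# Hypothetical shard lists, only to illustrate the round-robin selection.
train_feat_files = ['shard0.scp', 'shard1.scp', 'shard2.scp']
train_text_files = ['shard0.txt', 'shard1.txt', 'shard2.txt']

for epoch in range(5):
    feat = train_feat_files[epoch % len(train_feat_files)]
    text = train_text_files[epoch % len(train_text_files)]
    print(epoch, feat, text)  # epoch 3 wraps around to shard0.{scp,txt}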
Args: @@ -151,40 +152,45 @@ def load_dataset(self, split, combine=False, **kwargs): if split == 'train': feat_files = self.args.train_feat_files text_files = self.args.train_text_files - assert len(feat_files) > 0 and len(text_files) > 0 - elif re.match(r"^valid\d*$", split): - m = re.match(r"^valid(\d*)$", split) - idx = 0 if m.group(1) == '' else int(m.group(1)) - if idx >= len(self.args.valid_feat_files) or \ - idx >= len(self.args.valid_text_files): - raise FileNotFoundError - feat_files = [self.args.valid_feat_files[idx]] - text_files = [self.args.valid_text_files[idx]] - assert len(feat_files) > 0 and len(text_files) > 0 + assert len(feat_files) > 0 and len(feat_files) == len(text_files) + feat_files = [feat_files[epoch % len(feat_files)]] + text_files = [text_files[epoch % len(text_files)]] + elif split == 'valid': + feat_files = self.args.valid_feat_files + text_files = self.args.valid_text_files elif split == 'test': feat_files = self.args.test_feat_files - text_files = self.args.test_text_files - assert len(feat_files) > 0 and len(text_files) > 0 + text_files = self.args.test_text_files # can be empty + if text_files is None: + text_files = [None] * len(feat_files) + elif split == 'train_subset': + feat_files = self.args.train_subset_feat_files + text_files = self.args.train_subset_text_files else: - raise ValueError('split should be one of "train", "valid*", "test"') - assert len(feat_files) == len(text_files) + raise ValueError('split should be one of "train", "valid", "test", "train_subset"') + + assert len(feat_files) > 0 and len(feat_files) == len(text_files) file_pairs = zip(feat_files, text_files) for feat, text in file_pairs: - assert ScpCachedDataset.exists(feat) and TokenTextDataset.exists(text) + assert ScpCachedDataset.exists(feat) + assert text is None or TokenTextDataset.exists(text) src_datasets.append(ScpCachedDataset(feat, ordered_prefetch=True)) - tgt_datasets.append(TokenTextDataset(text, self.dict)) print('| {} {} examples'.format(feat, len(src_datasets[-1]))) - print('| {} {} examples'.format(text, len(tgt_datasets[-1]))) + if text is not None: + tgt_datasets.append(TokenTextDataset(text, self.dict)) + print('| {} {} examples'.format(text, len(tgt_datasets[-1]))) if not combine: break - assert len(src_datasets) == len(tgt_datasets) + if len(tgt_datasets) > 0: + assert len(src_datasets) == len(tgt_datasets) self.feat_dim = src_datasets[0].feat_dim if len(src_datasets) == 1: - src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0] + src_dataset = src_datasets[0] + tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None else: for i in range(1, len(src_datasets)): assert self.feat_dim == src_datasets[i].feat_dim, \ @@ -192,32 +198,61 @@ def load_dataset(self, split, combine=False, **kwargs): sample_ratios = [1] * len(src_datasets) sample_ratios[0] = self.args.upsample_primary src_dataset = ConcatDataset(src_datasets, sample_ratios) - tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) + tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) \ + if len(tgt_datasets) > 0 else None self.datasets[split] = SpeechDataset( src_dataset, src_dataset.sizes, - tgt_dataset, tgt_dataset.sizes, self.dict, + tgt_dataset, tgt_dataset.sizes if tgt_dataset is not None else None, + self.dict, left_pad_source=self.args.left_pad_source, left_pad_target=self.args.left_pad_target, max_source_positions=self.args.max_source_positions, max_target_positions=self.args.max_target_positions, ) + # update the counts of and in dictionary with training data + if split == 
'train': + self.dict.count[self.dict.eos()] = len(tgt_dataset) + unk_count = 0 + for i in range(len(tgt_dataset)): + unk_count += (tgt_dataset[i] == self.dict.unk()).int().sum().item() + self.dict.count[self.dict.unk()] = unk_count + def build_generator(self, args): if args.score_reference: args.score_reference = False print('| --score-reference is not applicable to speech recognition,' ' ignoring it.') - return super().build_generator(args) + from fairseq.sequence_generator import SequenceGenerator + return SequenceGenerator( + self.target_dictionary, + beam_size=args.beam, + max_len_a=args.max_len_a, + max_len_b=args.max_len_b, + min_len=args.min_len, + stop_early=(not args.no_early_stop), + normalize_scores=(not args.unnormalized), + len_penalty=args.lenpen, + unk_penalty=args.unkpen, + sampling=args.sampling, + sampling_topk=args.sampling_topk, + temperature=args.temperature, + diverse_beam_groups=args.diverse_beam_groups, + diverse_beam_strength=args.diverse_beam_strength, + match_source_len=args.match_source_len, + no_repeat_ngram_size=args.no_repeat_ngram_size, + coverage_weight=args.coverage_weight, + ) def build_dataset_for_inference(self, src_tokens, src_lengths): return SpeechDataset(src_tokens, src_lengths) def inference_step(self, generator, models, sample, prefix_tokens=None, - lprob_weights=None): + lm_weight=0.0): with torch.no_grad(): return generator.generate(models, sample, prefix_tokens=prefix_tokens, - lprob_weights=lprob_weights) + lm_weight=lm_weight) def max_positions(self): """Return the max sentence length allowed by the task.""" @@ -225,5 +260,10 @@ def max_positions(self): @property def target_dictionary(self): - """Return the target :class:`~fairseq.data.Dictionary`.""" + """Return the target :class:`~fairseq.data.TokenDictionary`.""" return self.dict + + @property + def word_dictionary(self): + """Return the target :class:`~fairseq.data.TokenDictionary`.""" + return self.word_dict diff --git a/fairseq/wer.py b/fairseq/wer.py index 03a541e94..abc60973c 100644 --- a/fairseq/wer.py +++ b/fairseq/wer.py @@ -23,6 +23,7 @@ def __init__(self, dict, wer_output_filter=None): def reset(self): self.char_counter = Counter() self.word_counter = Counter() + self.char_results = OrderedDict() self.results = OrderedDict() self.aligned_results = OrderedDict() @@ -42,7 +43,8 @@ def parse_wer_output_filter(self, wer_output_filter): assert m is not None self.word_filters.append([m.group(1), m.group(2)]) else: - print('Unsupported pattern: "' + line + '", ignored') + print('Unsupported pattern: "{}", ignored'.format(line), + file=sys.stderr) def add_prediction(self, utt_id, pred): if not isinstance(utt_id, str): @@ -50,6 +52,10 @@ def add_prediction(self, utt_id, pred): if not isinstance(pred, str): raise TypeError('pred must be a string(got {})'.format(type(pred))) + assert not utt_id in self.char_results, \ + 'Duplicated utterance id detected: {}'.format(utt_id) + self.char_results[utt_id] = pred + '\n' + pred_words = self.dict.tokens_to_sentence(pred) assert not utt_id in self.results, \ 'Duplicated utterance id detected: {}'.format(utt_id) @@ -132,11 +138,24 @@ def add_ordered_utt_list(self, *args): with open(text_file, 'r', encoding='utf-8') as f: one_utt_list = [line.strip().split()[0] for line in f] self.ordered_utt_list.extend(one_utt_list) + if len(self.char_results): + assert set(self.ordered_utt_list) == set(self.char_results.keys()) if len(self.results): assert set(self.ordered_utt_list) == set(self.results.keys()) if len(self.aligned_results): assert 
set(self.ordered_utt_list) == set(self.aligned_results.keys()) + def print_char_results(self): + res = '' + if self.ordered_utt_list is not None: + assert set(self.ordered_utt_list) == set(self.char_results.keys()) + for utt_id in self.ordered_utt_list: + res += utt_id + ' ' + self.char_results[utt_id] + else: + for utt_id in self.char_results: + res += utt_id + ' ' + self.char_results[utt_id] + return res + def print_results(self): res = '' if self.ordered_utt_list is not None: diff --git a/speech_recognize.py b/speech_recognize.py index 4ac27b07b..b617b6050 100755 --- a/speech_recognize.py +++ b/speech_recognize.py @@ -14,8 +14,10 @@ import torch -from fairseq import wer, options, progress_bar, tasks, utils +from fairseq import wer, checkpoint_utils, options, progress_bar, tasks, utils from fairseq.meters import StopwatchMeter, TimeMeter +from fairseq.models import FairseqLanguageModel +from fairseq.models.external_language_model import LookAheadWordLanguageModel, MultiLevelLanguageModel from fairseq.utils import import_user_module from speech_tools.utils import plot_attention @@ -25,7 +27,7 @@ def main(args): assert not args.sampling or args.nbest == args.beam, \ '--sampling requires --nbest to be equal to --beam' - import_user_module(args) + utils.import_user_module(args) if args.max_tokens is None and args.max_sentences is None: args.max_tokens = 12000 @@ -36,19 +38,37 @@ def main(args): # Load dataset split task = tasks.setup_task(args) task.load_dataset(args.gen_subset) - print('| {} {} examples'.format(args.gen_subset, - len(task.dataset(args.gen_subset)))) # Set dictionary dict = task.target_dictionary # Load ensemble print('| loading model(s) from {}'.format(args.path)) - models, _model_args = utils.load_ensemble_for_inference( - args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides), + models, _model_args = checkpoint_utils.load_model_ensemble( + args.path.split(':'), + arg_overrides=eval(args.model_overrides), + task=task, ) - if args.lprob_weights is not None: - print('| using model ensemble with lprob-weights={}'.format(str(args.lprob_weights))) + for i, m in enumerate(models): + if hasattr(m, 'is_wordlm') and m.is_wordlm: + # assume subword LM comes before word LM + if isinstance(models[i - 1], FairseqLanguageModel): + models[i-1] = MultiLevelLanguageModel(m, models[i-1], + subwordlm_weight=args.subwordlm_weight, + oov_penalty=args.oov_penalty, + open_vocab=not args.disable_open_vocab) + del models[i] + print('| LM fusion with Multi-level LM') + else: + models[i] = LookAheadWordLanguageModel(m, dict, + oov_penalty=args.oov_penalty, + open_vocab=not args.disable_open_vocab) + print('| LM fusion with Look-ahead Word LM') + # assume subword LM comes after E2E models + elif i == len(models) - 1 and isinstance(m, FairseqLanguageModel): + print('| LM fusion with Subword LM') + if args.lm_weight != 0.0: + print('| using LM fusion with lm-weight={:.2f}'.format(args.lm_weight)) # Optimize ensemble for generation for model in models: @@ -102,7 +122,7 @@ def main(args): gen_timer.start() hypos = task.inference_step(generator, models, sample, prefix_tokens, - lprob_weights=args.lprob_weights) + lm_weight=args.lm_weight) num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos) gen_timer.stop(num_generated_tokens) @@ -144,15 +164,22 @@ def main(args): t.log({'wps': round(wps_meter.avg)}) num_sentences += sample['nsentences'] - print('| Recognized {} utterances in {:.1f}s ({:.2f} utterances/s)'.format( - num_sentences, gen_timer.sum, 1. 
/ gen_timer.avg)) + print('| Recognized {} utterances ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format( + num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg)) if args.print_alignment: print('| Saved attention plots in ' + save_dir) - scorer.add_ordered_utt_list(*args.test_text_files) + if has_target: + assert args.test_text_files is not None + scorer.add_ordered_utt_list(*args.test_text_files) os.makedirs(args.results_path, exist_ok=True) + fn = 'decoded_char_results.txt' + with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f: + f.write(scorer.print_char_results()) + print('| Decoded char results saved as ' + f.name) + fn = 'decoded_results.txt' with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f: f.write(scorer.print_results()) @@ -189,15 +216,26 @@ def print_options_meaning_changes(args): """ print('| --max-tokens is the maximum number of input frames in a batch') if args.print_alignment: - print('| --print-alignment is set to plot attentions') + print('| --print-alignment has been set to plot attentions') def cli_main(): parser = options.get_generation_parser(default_task='speech_recognition') - parser.add_argument('--lprob-weights', default=None, type=options.eval_str_list, - metavar='W_1,W_2,...,W_N', - help='model ensemble weights in log-prob space, the same' - 'length as number of models specified in --path') + parser.add_argument('--coverage-weight', default=0.0, type=float, metavar='W', + help='coverage weight in log-prob space, mostly to ' + 'reduce deletion errors while using the pretrained ' + 'external LM for decoding') + parser.add_argument('--lm-weight', default=0.0, type=float, metavar='W', + help='LM weight in log-prob space, assuming the pretrained ' + 'external LM is specified as the second one in --path') + parser.add_argument('--subwordlm-weight', default=0.8, type=float, metavar='W', + help='subword LM weight relative to word LM. Only relevant ' + 'to MultiLevelLanguageModel as an external LM') + parser.add_argument('--oov-penalty', default=1e-4, type=float, + help='oov penalty with the pretrained external LM') + parser.add_argument('--disable-open-vocab', action='store_true', + help='whether open vocabulary mode is enabled with the ' + 'pretrained external LM') args = options.parse_args_and_arch(parser) assert args.results_path is not None, 'please specify --results-path' print_options_meaning_changes(args) diff --git a/speech_tools/Makefile b/speech_tools/Makefile index a9f4035b0..5fe67d1c0 100644 --- a/speech_tools/Makefile +++ b/speech_tools/Makefile @@ -6,7 +6,7 @@ all: kaldi kaldi-io-for-python kaldi-io-for-python: git clone https://github.com/vesis84/kaldi-io-for-python.git - ln -nfs kaldi-io-for-python/kaldi_io/kaldi_io.py kaldi_io.py + ln -sf kaldi-io-for-python/kaldi_io/kaldi_io.py kaldi_io.py ifneq ($(strip $(KALDI)),) kaldi: diff --git a/speech_tools/compute_wer.py b/speech_tools/compute_wer.py new file mode 100755 index 000000000..e0ae5257c --- /dev/null +++ b/speech_tools/compute_wer.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2019-present, Yiming Wang +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. 
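The --lm-weight, --oov-penalty and --coverage-weight options above control shallow fusion during decoding. As a rough illustration only (the actual combination lives inside SequenceGenerator and the external LM wrappers and may differ in detail), the fused score per beam step is assumed to look like the sketch below; the function name, shapes and the coverage term are hypothetical:

import torch

def fused_step_scores(e2e_lprobs, lm_lprobs, attn_mass,
                      lm_weight=0.7, coverage_weight=0.01):
    # e2e_lprobs, lm_lprobs: (batch*beam, vocab) log-probs from the E2E model
    # and from the external LM; attn_mass: (batch*beam,) accumulated attention
    # over source frames, used here as a simple stand-in for a coverage term
    # that discourages deletions.
    scores = e2e_lprobs + lm_weight * lm_lprobs            # fusion in log-prob space
    scores = scores + coverage_weight * attn_mass.unsqueeze(-1)
    return scores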
+ +import argparse +import sys, re +from collections import Counter + +from utils import edit_distance + + +def get_parser(): + parser = argparse.ArgumentParser( + description='Compute WER from text') + # fmt: off + parser.add_argument('--non-lang-syms', default=None, type=str, + help='path to a file listing non-linguistic symbols, ' + 'e.g., etc. One entry per line.') + parser.add_argument('--wer-output-filter', default=None, type=str, + help='path to wer_output_filter file for WER evaluation') + parser.add_argument('ref_text', type=str, + help='path to the reference text file') + parser.add_argument('hyp_text', type=str, + help='path to the hypothesis text file') + + # fmt: on + + return parser + + +def main(args): + non_lang_syms = [] + if args.non_lang_syms is not None: + with open(args.non_lang_syms, 'r', encoding='utf-8') as f: + non_lang_syms = [x.rstrip() for x in f.readlines()] + + word_filters = [] + if args.wer_output_filter is not None: + with open(args.wer_output_filter, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if line.startswith('#!') or line == '': + continue + elif line.startswith('s/'): + m = re.match(r's/(\S+)/(\w*)/g', line) + assert m is not None + word_filters.append([m.group(1), m.group(2)]) + elif line.startswith('s:'): + m = re.match(r's:(\S+):(\w*):g', line) + assert m is not None + word_filters.append([m.group(1), m.group(2)]) + else: + print('Unsupported pattern: "{}", ignored'.format(line), + file=sys.stderr) + + refs = {} + with open(args.ref_text, 'r', encoding='utf-8') as f: + for line in f: + utt_id, text = line.strip().split(None, 1) + assert utt_id not in refs, utt_id + refs[utt_id] = text + + wer_counter = Counter() + with open(args.hyp_text, 'r', encoding='utf-8') as f: + for line in f: + utt_id, text = line.strip().split(None, 1) + assert utt_id in refs, utt_id + ref, hyp = refs[utt_id], text + + # filter words according to word_filters (support re.sub only) + for pattern, repl in word_filters: + ref = re.sub(pattern, repl, ref) + hyp = re.sub(pattern, repl, hyp) + + # filter out any non_lang_syms from ref and hyp + ref_list = [x for x in ref.split() if x not in non_lang_syms] + hyp_list = [x for x in hyp.split() if x not in non_lang_syms] + + _, _, counter = edit_distance(ref_list, hyp_list) + wer_counter += counter + + assert wer_counter['words'] > 0 + wer = float(wer_counter['sub'] + wer_counter['ins'] + \ + wer_counter['del']) / wer_counter['words'] * 100 + sub = float(wer_counter['sub']) / wer_counter['words'] * 100 + ins = float(wer_counter['ins']) / wer_counter['words'] * 100 + dlt = float(wer_counter['del']) / wer_counter['words'] * 100 + + print('WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%, #words={:d}'.format( + wer, sub, ins, dlt, wer_counter['words'])) + + +if __name__ == '__main__': + parser = get_parser() + args = parser.parse_args() + main(args) diff --git a/speech_tools/text2token.py b/speech_tools/text2token.py index 06571237b..6b8a201af 100755 --- a/speech_tools/text2token.py +++ b/speech_tools/text2token.py @@ -21,7 +21,8 @@ def get_parser(): parser.add_argument('--space', default='', type=str, help='space symbol') parser.add_argument('--non-lang-syms', default=None, type=str, - help='list of non-linguistic symobles, e.g., etc.') + help='path to a file listing non-linguistic symbols, ' + 'e.g., etc. 
One entry per line.') parser.add_argument('text', type=str, nargs='?', help='input text') # fmt: on diff --git a/speech_tools/text2vocabulary.py b/speech_tools/text2vocabulary.py new file mode 100755 index 000000000..3b43e1ce6 --- /dev/null +++ b/speech_tools/text2vocabulary.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2019-present, Yiming Wang +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import argparse +import sys +from collections import Counter + + +def get_parser(): + parser = argparse.ArgumentParser( + description='Create a vocabulary from text files') + # fmt: off + parser.add_argument('--skip-ncols', default=0, type=int, + help='skip first n columns') + parser.add_argument('--cutoff', default=0, type=int, + help='cut-off frequency') + parser.add_argument('--vocabsize', default=20000, type=int, + help='vocabulary size') + parser.add_argument('--exclude', type=str, default=None, + help='space separated, list of excluding words, ' + 'e.g., etc.') + parser.add_argument('--valid-text', type=str, default=None, + help='path to the validation text') + parser.add_argument('--test-text', type=str, default=None, + help='path to the test text') + parser.add_argument('text_files', nargs='*', + help='input text files') + # fmt: on + + return parser + + +def main(args): + exclude = args.exclude.split(' ') if args.exclude is not None else [] + if len(args.text_files) == 0: + args.text_files.append('-') + + counter = Counter() + for fn in args.text_files: + with (open(fn, 'r', encoding='utf-8') if fn != '-' else sys.stdin) as f: + for line in f: + tokens = line.rstrip().split()[args.skip_ncols:] + tokens = [tok for tok in tokens if tok not in exclude] + counter.update(tokens) + + total_count = sum(counter.values()) + most_common = counter.most_common(args.vocabsize) + cutoff_point = 0 + invocab_count = 0 + for elem in most_common: + if elem[1] < args.cutoff: + break + invocab_count += elem[1] + cutoff_point += 1 + cutoff_freq = most_common[cutoff_point - 1][1] + most_common = most_common[:cutoff_point] + + oov_rate = 1. - float(invocab_count) / total_count + print('training set:', file=sys.stderr) + print(' total #tokens={:d}'.format(total_count), file=sys.stderr) + print(' OOV rate={:.2f}%'.format(oov_rate * 100), file=sys.stderr) + print(' cutoff frequency={:d}'.format(cutoff_freq), file=sys.stderr) + + # words in vocabulary are lexically sorted + for w, c in sorted(most_common, key=lambda x: x[0]): + print('{} {:d}'.format(w, c)) + + vocab_set = set(list(zip(*most_common))[0]) + if args.valid_text is not None: + total_count = 0 + invocab_count = 0 + with open(args.valid_text, 'r', encoding='utf-8') as f: + for line in f: + tokens = line.rstrip().split()[args.skip_ncols:] + tokens = [tok for tok in tokens if tok not in exclude] + total_count += len(tokens) + invocab_count += len([tok for tok in tokens if tok in vocab_set]) + oov_rate = 1. 
- float(invocab_count) / total_count + print('validation set:', file=sys.stderr) + print(' total #tokens={:d}'.format(total_count), file=sys.stderr) + print(' OOV rate={:.2f}%'.format(oov_rate * 100), file=sys.stderr) + + if args.test_text is not None: + total_count = 0 + invocab_count = 0 + with open(args.test_text, 'r', encoding='utf-8') as f: + for line in f: + tokens = line.rstrip().split()[args.skip_ncols:] + tokens = [tok for tok in tokens if tok not in exclude] + total_count += len(tokens) + invocab_count += len([tok for tok in tokens if tok in vocab_set]) + oov_rate = 1. - float(invocab_count) / total_count + print('test set:', file=sys.stderr) + print(' total #tokens={:d}'.format(total_count), file=sys.stderr) + print(' OOV rate={:.2f}%'.format(oov_rate * 100), file=sys.stderr) + + +if __name__ == '__main__': + parser = get_parser() + args = parser.parse_args() + main(args) diff --git a/speech_tools/utils.py b/speech_tools/utils.py index 2b6843f51..db42b6b22 100644 --- a/speech_tools/utils.py +++ b/speech_tools/utils.py @@ -11,7 +11,7 @@ import torch -from fairseq.utils import buffered_arange, item +from fairseq import utils def tokenize(sent, space='', non_lang_syms=None): @@ -56,7 +56,7 @@ def sequence_mask(sequence_length, max_len=None): if max_len is None: max_len = sequence_length.data.max() else: - assert item(sequence_length.data.max()) <= item(max_len) + assert sequence_length.data.max().item() <= utils.item(max_len) batch_size = sequence_length.size(0) seq_range = torch.arange(0, max_len).to(device=sequence_length.device, dtype=sequence_length.dtype) @@ -78,7 +78,7 @@ def convert_padding_direction(src_frames, src_lengths, right_to_left=False, if not src_lengths.eq(max_len).any(): # no padding, return early return src_frames - range = buffered_arange(max_len).unsqueeze(-1).expand_as(src_frames) + range = utils.buffered_arange(max_len).unsqueeze(-1).expand_as(src_frames) num_pads = (max_len - src_lengths.type_as(range)).unsqueeze(-1).unsqueeze(-1) if right_to_left: index = torch.remainder(range - num_pads, max_len) @@ -250,5 +250,57 @@ def aligned_print(ref, hyp, steps): wer = float(counter['sub'] + counter['ins'] + counter['del']) / len(ref) \ * 100 out_str += 'WER: ' + '{:.2f}%'.format(wer) + '\n' + out_str += '\n' return out_str + +def lexical_prefix_tree(word_dict, subword_dict, subword_tokenizer=None): + """Build a lexical prefix tree for words. + + Args: + word_dict: an instance of :class:`fairseq.data.TokenDictionary`. + subword_dict: an instance of :class:`fairseq.data.TokenDictionary`. + subword_tokenizer (callable): a function that takes a word string as its + only one argument, and returns a list of subwords as a result of + tokenization. + + Return: + root (Node): the root of the prefix tree, where each node has the fields: + ('children': Dict[int,Node], 'word_idx': int, 'word_set': Tuple[int]). + 'children' is subword_idx -> node, and 'word_set' is (first-1, last), + where [first, last] is the range of the word indexes (inclusive) in + the word dictionary who share the same prefix at that node. + We assume words in the word dictionary are in lexical order. 
+ """ + + class Node(object): + def __init__(self, children={}, word_idx=-1, word_set=None): + self.children = children + self.word_idx = word_idx + self.word_set = word_set + + special_symbols = [word_dict.pad(), word_dict.eos(), word_dict.unk()] + assert 0 in special_symbols # to ensure widx - 1 >= 0 + root = Node({}, -1, None) + for widx in range(len(word_dict)): + if widx not in special_symbols: # skip , , + # tokenize a word into a list of subwords + subwords = subword_tokenizer(word_dict[widx]) \ + if subword_tokenizer is not None else list(word_dict[widx]) + if any(subword_dict.index(s) == subword_dict.unk() for s in subwords): + # skip words containing any unknown subwords + continue + children = root.children + for i, s in enumerate(subwords): + sidx = subword_dict.index(s) + if sidx not in children: # make a new node + children[sidx] = Node({}, -1, (widx - 1, widx)) + else: + children[sidx].word_set = ( + min(children[sidx].word_set[0], widx - 1), + max(children[sidx].word_set[1], widx) + ) + if i == len(subwords) - 1: # if word end, set word_idx + children[sidx].word_idx = widx + children = children[sidx].children # move to children + return root diff --git a/speech_train.py b/speech_train.py index 9ea3b9c93..81eb95ed4 100755 --- a/speech_train.py +++ b/speech_train.py @@ -11,14 +11,13 @@ """ import collections -import itertools -import os import math +import os import random import torch -from fairseq import distributed_utils, options, progress_bar, tasks, utils +from fairseq import checkpoint_utils, distributed_utils, options, progress_bar, tasks, utils from fairseq.data import iterators from fairseq.trainer import Trainer from fairseq.meters import AverageMeter, StopwatchMeter @@ -26,23 +25,27 @@ def main(args, init_distributed=False): - import_user_module(args) + utils.import_user_module(args) - if args.max_tokens is None: - args.max_tokens = 6000 - print(args) + assert args.max_tokens is not None or args.max_sentences is not None, \ + 'Must specify batch size either with --max-tokens or --max-sentences' + # Initialize CUDA and distributed training if torch.cuda.is_available() and not args.cpu: torch.cuda.set_device(args.device_id) torch.manual_seed(args.seed) - if args.disable_cudnn: - torch.backends.cudnn.enabled = False + if init_distributed: + args.distributed_rank = distributed_utils.distributed_init(args) + + # Print args + print(args) # Setup task, e.g., translation, language modeling, etc. task = tasks.setup_task(args) - # Load dataset splits - load_dataset_splits(args, task) + # Load valid dataset (we load training data below, based on the latest checkpoint) + for valid_sub_split in args.valid_subset.split(','): + task.load_dataset(valid_sub_split, combine=True, epoch=0) # Build model and criterion model = task.build_model(args) @@ -54,47 +57,20 @@ def main(args, init_distributed=False): sum(p.numel() for p in model.parameters() if p.requires_grad), )) - # Make a dummy batch to (i) warm the caching allocator and (ii) as a - # placeholder DistributedDataParallel when there's an uneven number of - # batches per worker. 
- max_positions = utils.resolve_max_positions( - task.max_positions(), - model.max_positions(), - ) - dummy_batch = task.dataset(args.train_subset).get_dummy_batch(args.max_tokens, max_positions) - oom_batch = task.dataset(args.train_subset).get_dummy_batch(1, max_positions) - # Build trainer - trainer = Trainer(args, task, model, criterion, dummy_batch, oom_batch) + trainer = Trainer(args, task, model, criterion) print('| training on {} GPUs'.format(args.distributed_world_size)) print('| max input frames per GPU = {} and max sentences per GPU = {}'.format( args.max_tokens, args.max_sentences, )) - # Initialize dataloader - epoch_itr = task.get_batch_iterator( - dataset=task.dataset(args.train_subset), - max_tokens=args.max_tokens, - max_sentences=args.max_sentences, - max_positions=max_positions, - ignore_invalid_inputs=True, - required_batch_size_multiple=8, - seed=args.seed, - num_shards=args.distributed_world_size, - shard_id=args.distributed_rank, - num_workers=args.num_workers, - ) + # Load the latest checkpoint if one is available and restore the + # corresponding train iterator + extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer) - # Initialize distributed training (after data loading) - if init_distributed: - import socket - args.distributed_rank = distributed_utils.distributed_init(args) - print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank)) - - # Load the latest checkpoint if one is available - if not load_checkpoint(args, trainer, epoch_itr): - trainer.dummy_train_step([dummy_batch]) + if callable(getattr(trainer.criterion, 'set_train_tgt_dataset', None)): + trainer.criterion.set_train_tgt_dataset(task.dataset(args.train_subset).tgt) # Train until the learning rate gets too small max_epoch = args.max_epoch or math.inf @@ -104,13 +80,13 @@ def main(args, init_distributed=False): train_meter.start() valid_losses, valid_wers = [None], [None] valid_subsets = args.valid_subset.split(',') - while lr >= args.min_lr and (epoch_itr.epoch < max_epoch or \ + while lr > args.min_lr and (epoch_itr.epoch < max_epoch or \ (epoch_itr.epoch == max_epoch and epoch_itr._next_epoch_itr is not None)) and \ trainer.get_num_updates() < max_update: # train for one epoch train(args, trainer, task, epoch_itr) - if epoch_itr.epoch % args.validate_interval == 0: + if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0: valid_losses, valid_wers = validate(args, trainer, task, epoch_itr, valid_subsets) # only use first validation wer to update the learning rate @@ -118,7 +94,11 @@ def main(args, init_distributed=False): # save checkpoint if epoch_itr.epoch % args.save_interval == 0: - save_checkpoint(args, trainer, epoch_itr, valid_wers[0]) + checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_wers[0]) + + if len(args.train_feat_files) > 1: + # sharded data: get train iterator for next epoch + epoch_itr = trainer.get_train_iterator(epoch_itr.epoch) train_meter.stop() print('| done training in {:.1f} seconds'.format(train_meter.sum)) @@ -128,7 +108,7 @@ def train(args, trainer, task, epoch_itr): # Update parameters every N batches update_freq = args.update_freq[epoch_itr.epoch - 1] \ - if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1] + if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1] # Initialize data iterator itr = epoch_itr.next_epoch_itr( @@ -141,7 +121,7 @@ def train(args, trainer, task, epoch_itr): ) extra_meters = collections.defaultdict(lambda: AverageMeter()) 
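For reference, the sharded-training behavior introduced in load_dataset() above reduces to the small sketch below (names hypothetical): with several --train-feat-files, each epoch trains on one shard, chosen by epoch modulo the number of shards, and speech_train.py then requests a fresh train iterator at the end of every epoch.

def pick_train_shard(feat_files, text_files, epoch):
    # mirrors feat_files[epoch % len(feat_files)] in load_dataset()
    assert len(feat_files) > 0 and len(feat_files) == len(text_files)
    i = epoch % len(feat_files)
    return feat_files[i], text_files[i]

# with two shards, even epochs use shard 0 and odd epochs use shard 1:
assert pick_train_shard(['a.scp', 'b.scp'], ['a.txt', 'b.txt'], 3) == ('b.scp', 'b.txt')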
- first_valid = args.valid_subset.split(',')[0] + valid_subsets = args.valid_subset.split(',') max_update = args.max_update or math.inf for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch): if callable(getattr(trainer.criterion, 'set_num_updates', None)): @@ -168,9 +148,14 @@ def train(args, trainer, task, epoch_itr): trainer.get_meter('wps').reset() num_updates = trainer.get_num_updates() - if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0 and num_updates > 0: - valid_losses, valid_wers = validate(args, trainer, task, epoch_itr, [first_valid]) - save_checkpoint(args, trainer, epoch_itr, valid_wers[0]) + if ( + not args.disable_validation + and args.save_interval_updates > 0 + and num_updates % args.save_interval_updates == 0 + and num_updates > 0 + ): + valid_losses, valid_wers = validate(args, trainer, task, epoch_itr, valid_subsets) + checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_wers[0]) if num_updates >= max_update: break @@ -198,7 +183,7 @@ def get_training_stats(trainer): stats['nll_loss'] = nll_loss else: nll_loss = trainer.get_meter('train_loss') - stats['ppl'] = get_perplexity(nll_loss.avg) + stats['ppl'] = utils.get_perplexity(nll_loss.avg) stats['wps'] = trainer.get_meter('wps') stats['ups'] = trainer.get_meter('ups') stats['wpb'] = trainer.get_meter('wpb') @@ -274,8 +259,8 @@ def validate(args, trainer, task, epoch_itr, subsets): stats = get_valid_stats(trainer) for k, meter in extra_meters.items(): stats[k] = meter.avg - if hasattr(save_checkpoint, 'best'): - stats['best_wer'] = min(save_checkpoint.best, stats['wer']) + if hasattr(checkpoint_utils.save_checkpoint, 'best'): + stats['best_wer'] = min(checkpoint_utils.save_checkpoint.best, stats['wer']) progress.print(stats, tag=subset, step=trainer.get_num_updates()) valid_losses.append(stats['loss'].avg) @@ -291,118 +276,15 @@ def get_valid_stats(trainer): stats['nll_loss'] = nll_loss else: nll_loss = stats['loss'] - stats['ppl'] = get_perplexity(nll_loss.avg) + stats['ppl'] = utils.get_perplexity(nll_loss.avg) stats['num_updates'] = trainer.get_num_updates() return stats -def get_perplexity(loss): - try: - return '{:.2f}'.format(math.pow(2, loss)) - except OverflowError: - return float('inf') - - -def save_checkpoint(args, trainer, epoch_itr, val_wer): - if args.no_save or not distributed_utils.is_master(args): - return - epoch = epoch_itr.epoch - end_of_epoch = epoch_itr.end_of_epoch() - updates = trainer.get_num_updates() - - checkpoint_conds = collections.OrderedDict() - checkpoint_conds['checkpoint{}.pt'.format(epoch)] = ( - end_of_epoch and not args.no_epoch_checkpoints and - epoch % args.save_interval == 0 - ) - checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = ( - not end_of_epoch and args.save_interval_updates > 0 and - updates % args.save_interval_updates == 0 - ) - checkpoint_conds['checkpoint_best.pt'] = ( - val_wer is not None and - (not hasattr(save_checkpoint, 'best') or val_wer < save_checkpoint.best) - ) - checkpoint_conds['checkpoint_last.pt'] = True # keep this last so that it's a symlink - - prev_best = getattr(save_checkpoint, 'best', val_wer) - if val_wer is not None: - save_checkpoint.best = min(val_wer, prev_best) - extra_state = { - 'train_iterator': epoch_itr.state_dict(), - 'val_wer': val_wer, - } - if hasattr(save_checkpoint, 'best'): - extra_state.update({'best': save_checkpoint.best}) - - checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond] - if len(checkpoints) > 
0: - for cp in checkpoints: - trainer.save_checkpoint(cp, extra_state) - - if not end_of_epoch and args.keep_interval_updates > 0: - # remove old checkpoints; checkpoints are sorted in descending order - checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt') - for old_chk in checkpoints[args.keep_interval_updates:]: - if os.path.lexists(old_chk): - os.remove(old_chk) - - if args.keep_last_epochs > 0: - # remove old epoch checkpoints; checkpoints are sorted in descending order - checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint(\d+)\.pt') - for old_chk in checkpoints[args.keep_last_epochs:]: - if os.path.lexists(old_chk): - os.remove(old_chk) - - -def load_checkpoint(args, trainer, epoch_itr): - """Load a checkpoint and replay dataloader to match.""" - - # Only rank 0 should attempt to create the required dir - if args.distributed_rank == 0: - os.makedirs(args.save_dir, exist_ok=True) - - if os.path.isabs(args.restore_file): - checkpoint_path = args.restore_file - else: - checkpoint_path = os.path.join(args.save_dir, args.restore_file) - if os.path.isfile(checkpoint_path): - extra_state = trainer.load_checkpoint(checkpoint_path, args.reset_optimizer, args.reset_lr_scheduler, - eval(args.optimizer_overrides)) - if extra_state is not None: - # replay train iterator to match checkpoint - epoch_itr.load_state_dict(extra_state['train_iterator']) - - print('| loaded checkpoint {} (epoch {} @ {} updates)'.format( - checkpoint_path, epoch_itr.epoch, trainer.get_num_updates())) - - trainer.lr_step(epoch_itr.epoch) - trainer.lr_step_update(trainer.get_num_updates()) - if 'best' in extra_state: - save_checkpoint.best = extra_state['best'] - return True - else: - print('| no existing checkpoint found {}'.format(checkpoint_path)) - return False - - -def load_dataset_splits(args, task): - task.load_dataset(args.train_subset, combine=True) - for split in args.valid_subset.split(','): - for k in itertools.count(): - split_k = split + (str(k) if k > 0 else '') - try: - task.load_dataset(split_k, combine=False) - except FileNotFoundError as e: - if k > 0: - break - raise e - - -def distributed_main(i, args): +def distributed_main(i, args, start_rank=0): args.device_id = i if args.distributed_rank is None: # torch.multiprocessing.spawn - args.distributed_rank = i + args.distributed_rank = start_rank + i main(args, init_distributed=True) @@ -415,9 +297,6 @@ def print_options_meaning_changes(args): def cli_main(): parser = options.get_training_parser(default_task='speech_recognition') - parser.add_argument('--disable-cudnn', action='store_true', - help='disable cudnn, which would make the training ' - 'much slower') args = options.parse_args_and_arch(parser) print_options_meaning_changes(args) @@ -426,9 +305,19 @@ def cli_main(): if args.distributed_init_method is not None: # distributed training - distributed_main(args.device_id, args) + if torch.cuda.device_count() > 1 and not args.distributed_no_spawn: + start_rank = args.distributed_rank + args.distributed_rank = None # assign automatically + torch.multiprocessing.spawn( + fn=distributed_main, + args=(args, start_rank), + nprocs=torch.cuda.device_count(), + ) + else: + distributed_main(args.device_id, args) elif args.distributed_world_size > 1: # fallback for single node with multiple GPUs + assert args.distributed_world_size <= torch.cuda.device_count() port = random.randint(10000, 20000) args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port) args.distributed_rank = None # set 
based on device id From cd8803dc222da011655d186b7b33d915a6947833 Mon Sep 17 00:00:00 2001 From: freewym Date: Sun, 26 May 2019 01:50:15 -0400 Subject: [PATCH 019/119] add scheduled sampling training support --- examples/asr_wsj/run.sh | 11 ++-- fairseq/criterions/cross_entropy_with_wer.py | 53 +++++++++++++++++-- .../label_smoothed_cross_entropy_with_wer.py | 53 +++++++++++++++++-- speech_train.py | 2 + 4 files changed, 104 insertions(+), 15 deletions(-) diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index 177d28077..63a38537f 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -11,7 +11,7 @@ set -e -o pipefail stage=0 ngpus=1 # num GPUs for multiple GPUs training within a single node; should match those in $free_gpu -free_gpu= # comma-separated available GPU ids, eg., "0" or "0,1"; automatically assigned on CLSP grid +free_gpu= # comma-separated available GPU ids, eg., "0" or "0,1"; automatically assigned if on CLSP grid # E2E model related affix= @@ -53,8 +53,6 @@ wordlmdir=exp/wordlm_lstm${wordlm_affix:+_${wordlm_affix}} dir=exp/lstm${affix:+_$affix} if [ ${stage} -le 0 ]; then - ### Task dependent. You have to make data the following preparation part by yourself. - ### But you can utilize Kaldi recipes in most cases echo "Stage 0: Data Preparation" local/wsj_data_prep.sh ${wsj0}/??-{?,??}.? ${wsj1}/??-{?,??}.? echo "Preparing train and test data" @@ -75,8 +73,6 @@ train_subset_feat_dir=${dumpdir}/${train_set}_${train_subset_size}/delta${do_del valid_feat_dir=${dumpdir}/${valid_set}/delta${do_delta}; mkdir -p ${valid_feat_dir} test_feat_dir=${dumpdir}/${test_set}/delta${do_delta}; mkdir -p ${test_feat_dir} if [ ${stage} -le 1 ]; then - ### Task dependent. You have to design training and dev sets by yourself. - ### But you can utilize Kaldi recipes in most cases echo "Stage 1: Feature Generation" fbankdir=fbank # Generate the fbank features; by default 80-dimensional fbanks with pitch on each frame @@ -275,12 +271,13 @@ if [ ${stage} -le 8 ]; then --num-workers 0 --max-tokens 24000 --max-sentences 32 \ --valid-subset $valid_subset --max-sentences-valid 64 \ --distributed-world-size $ngpus --distributed-rank 0 --distributed-port 100 \ - --max-epoch 25 --optimizer adam --lr 0.001 --weight-decay 0.0 \ + --max-epoch 30 --optimizer adam --lr 0.001 --weight-decay 0.0 \ --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --min-lr 1e-5 --start-reduce-lr-epoch 11 \ --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates 400 \ --keep-interval-updates 5 --keep-last-epochs 5 --validate-interval 1 \ --arch speech_conv_lstm_wsj --criterion label_smoothed_cross_entropy_with_wer \ --label-smoothing 0.05 --smoothing-type temporal \ + --scheduled-sampling-probs 0.4 --start-scheduled-sampling-epoch 11 \ --train-feat-files $train_feat --train-text-files $train_token_text \ --valid-feat-files $valid_feat --valid-text-files $valid_token_text \ --dict $dict --non-lang-syms $nlsyms \ @@ -299,7 +296,7 @@ if [ ${stage} -le 9 ]; then decode_affix=shallow_fusion else path="$path:$wordlmdir/$lm_checkpoint" - opts="$opts --word-dict $wordlmdict --lm-weight 0.7 --oov-penalty 1e-4 --coverage-weight 0.01" + opts="$opts --word-dict $wordlmdict --lm-weight 0.8 --oov-penalty 1e-6 --coverage-weight 0.01" decode_affix=shallow_fusion_wordlm fi fi diff --git a/fairseq/criterions/cross_entropy_with_wer.py b/fairseq/criterions/cross_entropy_with_wer.py index 07412bee3..63dbb3e23 100644 --- a/fairseq/criterions/cross_entropy_with_wer.py +++ 
b/fairseq/criterions/cross_entropy_with_wer.py @@ -12,6 +12,7 @@ from fairseq import utils, wer from fairseq.data import data_utils from fairseq.models import FairseqIncrementalDecoder +from fairseq.options import eval_str_list from . import FairseqCriterion, register_criterion from .cross_entropy import CrossEntropyCriterion @@ -29,6 +30,7 @@ def __init__(self, args, task): self.train_tgt_dataset = task.dataset(args.train_subset).tgt self.valid_tgt_dataset = None self.num_updates = -1 + self.epoch = 0 @staticmethod def add_args(parser): @@ -38,6 +40,14 @@ def add_args(parser): metavar='N', dest='print_interval', default=500, help='print a training sample (reference + ' 'prediction) every this number of updates') + parser.add_argument('--scheduled-sampling-probs', type=eval_str_list, + metavar='P_1,P_2,...,P_N', default=1.0, + help='schedule sampling probabilities of sampling the truth ' + 'labels for N epochs starting from --start-schedule-sampling-epoch; ' + 'all later epochs using P_N') + parser.add_argument('--start-scheduled-sampling-epoch', type=int, + metavar='N', default=1, + help='start schedule sampling from the specified epoch') # fmt: on def forward(self, model, sample, reduce=True): @@ -52,9 +62,42 @@ def forward(self, model, sample, reduce=True): """ dict = self.scorer.dict if model.training: - net_output = model(**sample['net_input']) - lprobs = model.get_normalized_probs(net_output, log_probs=True) - target = model.get_targets(sample, net_output) + if (len(self.args.scheduled_sampling_probs) > 1 or \ + self.args.scheduled_sampling_probs[0] < 1.0) and \ + self.epoch >= self.args.start_scheduled_sampling_epoch: + # scheduled sampling + ss_prob = self.args.scheduled_sampling_probs[ + min(self.epoch - self.args.start_scheduled_sampling_epoch, + len(self.args.scheduled_sampling_probs) - 1) + ] + assert isinstance(model.decoder, FairseqIncrementalDecoder) + incremental_states = {} + encoder_input = { + k: v for k, v in sample['net_input'].items() + if k != 'prev_output_tokens' + } + encoder_out = model.encoder(**encoder_input) + target = sample['target'] + tokens = sample['net_input']['prev_output_tokens'] + lprobs = [] + for step in range(target.size(1)): + if step > 0: + sampling_mask = torch.rand([target.size(0), 1], + device=target.device).lt(ss_prob) + feed_tokens = torch.where(sampling_mask, + tokens[:, step:step + 1], pred) + else: + feed_tokens = tokens[:, step:step + 1] + log_probs, _ = self._decode(feed_tokens, + model, encoder_out, incremental_states) + pred = log_probs.argmax(-1,keepdim=True) + lprobs.append(log_probs) + lprobs = torch.stack(lprobs, dim=1) + else: + # normal training + net_output = model(**sample['net_input']) + lprobs = model.get_normalized_probs(net_output, log_probs=True) + target = model.get_targets(sample, net_output) else: assert isinstance(model.decoder, FairseqIncrementalDecoder) incremental_states = {} @@ -66,7 +109,6 @@ def forward(self, model, sample, reduce=True): target = sample['target'] # make the maximum decoding length equal to at least the length of # target, and the length of encoder_out if possible - # and at least the length of target maxlen = max(encoder_out['encoder_out'][0].size(1), target.size(1)) tokens = target.new_full([target.size(0), maxlen + 2], self.padding_idx) tokens[:, 0] = dict.eos() @@ -187,3 +229,6 @@ def set_valid_tgt_dataset(self, dataset): def set_num_updates(self, num_updates): self.num_updates = num_updates + + def set_epoch(self, epoch): + self.epoch = epoch diff --git 
a/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py b/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py index 7b421bd96..0f4b1494c 100644 --- a/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py +++ b/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py @@ -11,6 +11,7 @@ from fairseq import utils, wer from fairseq.data import data_utils from fairseq.models import FairseqIncrementalDecoder +from fairseq.options import eval_str_list from . import FairseqCriterion, register_criterion from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion @@ -28,6 +29,7 @@ def __init__(self, args, task): self.train_tgt_dataset = None self.valid_tgt_dataset = None self.num_updates = -1 + self.epoch = 0 if args.smoothing_type == 'unigram': self.unigram_tensor = torch.cuda.FloatTensor(dict.count).unsqueeze(-1) \ if torch.cuda.is_available() and not args.cpu \ @@ -46,6 +48,14 @@ def add_args(parser): parser.add_argument('--smoothing-type', type=str, default='uniform', choices=['uniform', 'unigram', 'temporal'], help='label smoothing type. Default: uniform') + parser.add_argument('--scheduled-sampling-probs', type=eval_str_list, + metavar='P_1,P_2,...,P_N', default=1.0, + help='scheduled sampling probabilities of sampling the truth ' + 'labels for N epochs starting from --start-schedule-sampling-epoch; ' + 'all later epochs using P_N') + parser.add_argument('--start-scheduled-sampling-epoch', type=int, + metavar='N', default=1, + help='start scheduled sampling from the specified epoch') # fmt: on def forward(self, model, sample, reduce=True): @@ -60,9 +70,42 @@ def forward(self, model, sample, reduce=True): """ dict = self.scorer.dict if model.training: - net_output = model(**sample['net_input']) - lprobs = model.get_normalized_probs(net_output, log_probs=True) - target = model.get_targets(sample, net_output) + if (len(self.args.scheduled_sampling_probs) > 1 or \ + self.args.scheduled_sampling_probs[0] < 1.0) and \ + self.epoch >= self.args.start_scheduled_sampling_epoch: + # scheduled sampling + ss_prob = self.args.scheduled_sampling_probs[ + min(self.epoch - self.args.start_scheduled_sampling_epoch, + len(self.args.scheduled_sampling_probs) - 1) + ] + assert isinstance(model.decoder, FairseqIncrementalDecoder) + incremental_states = {} + encoder_input = { + k: v for k, v in sample['net_input'].items() + if k != 'prev_output_tokens' + } + encoder_out = model.encoder(**encoder_input) + target = sample['target'] + tokens = sample['net_input']['prev_output_tokens'] + lprobs = [] + for step in range(target.size(1)): + if step > 0: + sampling_mask = torch.rand([target.size(0), 1], + device=target.device).lt(ss_prob) + feed_tokens = torch.where(sampling_mask, + tokens[:, step:step + 1], pred) + else: + feed_tokens = tokens[:, step:step + 1] + log_probs, _ = self._decode(feed_tokens, + model, encoder_out, incremental_states) + pred = log_probs.argmax(-1, keepdim=True) + lprobs.append(log_probs) + lprobs = torch.stack(lprobs, dim=1) + else: + # normal training + net_output = model(**sample['net_input']) + lprobs = model.get_normalized_probs(net_output, log_probs=True) + target = model.get_targets(sample, net_output) else: assert isinstance(model.decoder, FairseqIncrementalDecoder) incremental_states = {} @@ -74,7 +117,6 @@ def forward(self, model, sample, reduce=True): target = sample['target'] # make the maximum decoding length equal to at least the length of # target, and the length of encoder_out if possible - # and at least the length of target maxlen = 
max(encoder_out['encoder_out'][0].size(1), target.size(1)) tokens = target.new_full([target.size(0), maxlen + 2], self.padding_idx) tokens[:, 0] = dict.eos() @@ -234,3 +276,6 @@ def set_valid_tgt_dataset(self, dataset): def set_num_updates(self, num_updates): self.num_updates = num_updates + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/speech_train.py b/speech_train.py index 81eb95ed4..f51947a23 100755 --- a/speech_train.py +++ b/speech_train.py @@ -123,6 +123,8 @@ def train(args, trainer, task, epoch_itr): extra_meters = collections.defaultdict(lambda: AverageMeter()) valid_subsets = args.valid_subset.split(',') max_update = args.max_update or math.inf + if callable(getattr(trainer.criterion, 'set_epoch', None)): + trainer.criterion.set_epoch(epoch_itr.epoch) for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch): if callable(getattr(trainer.criterion, 'set_num_updates', None)): trainer.criterion.set_num_updates(trainer.get_num_updates()) From 82a11751a9ec46882dfa3eff72490d29f9f2d3dc Mon Sep 17 00:00:00 2001 From: freewym Date: Sun, 26 May 2019 19:32:35 -0400 Subject: [PATCH 020/119] add bpe support & librispeech recipe --- examples/asr_librispeech/cmd.sh | 20 ++ examples/asr_librispeech/conf/fbank.conf | 2 + examples/asr_librispeech/conf/pitch.conf | 1 + examples/asr_librispeech/local/data_prep.sh | 1 + .../local/download_and_untar.sh | 1 + examples/asr_librispeech/local/score.sh | 1 + examples/asr_librispeech/path.sh | 17 ++ examples/asr_librispeech/run.sh | 254 ++++++++++++++++++ examples/asr_librispeech/steps | 1 + examples/asr_librispeech/utils | 1 + fairseq/criterions/cross_entropy_with_wer.py | 9 +- .../label_smoothed_cross_entropy_with_wer.py | 9 +- fairseq/data/scp_dataset.py | 4 +- fairseq/data/token_dictionary.py | 5 +- fairseq/models/speech_lstm.py | 26 ++ fairseq/wer.py | 10 +- speech_recognize.py | 10 +- speech_tools/.gitignore | 1 + speech_tools/Makefile | 16 +- speech_train.py | 4 + 20 files changed, 369 insertions(+), 24 deletions(-) create mode 100644 examples/asr_librispeech/cmd.sh create mode 100644 examples/asr_librispeech/conf/fbank.conf create mode 100644 examples/asr_librispeech/conf/pitch.conf create mode 120000 examples/asr_librispeech/local/data_prep.sh create mode 120000 examples/asr_librispeech/local/download_and_untar.sh create mode 120000 examples/asr_librispeech/local/score.sh create mode 100644 examples/asr_librispeech/path.sh create mode 100755 examples/asr_librispeech/run.sh create mode 120000 examples/asr_librispeech/steps create mode 120000 examples/asr_librispeech/utils diff --git a/examples/asr_librispeech/cmd.sh b/examples/asr_librispeech/cmd.sh new file mode 100644 index 000000000..b14280b96 --- /dev/null +++ b/examples/asr_librispeech/cmd.sh @@ -0,0 +1,20 @@ +# you can change cmd.sh depending on what type of queue you are using. +# If you have no queueing system and want to run on a local machine, you +# can change all instances 'queue.pl' to run.pl (but be careful and run +# commands one by one: most recipes will exhaust the memory on your +# machine). queue.pl works with GridEngine (qsub). slurm.pl works +# with slurm. Different queues are configured differently, with different +# queue names and different ways of specifying things like memory; +# to account for these differences you can create and edit the file +# conf/queue.conf to match your queue's configuration. 
Search for +# conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information, +# or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl. + +#export train_cmd="run.pl --mem 4G" +#export cuda_cmd="run.pl --mem 4G --gpu 1" +#export decode_cmd="run.pl --mem 4G" + +# JHU setup +export train_cmd="queue.pl --mem 4G" +export cuda_cmd="queue.pl --mem 4G --gpu 1 --config conf/gpu.conf" +export decode_cmd="queue.pl --mem 4G" diff --git a/examples/asr_librispeech/conf/fbank.conf b/examples/asr_librispeech/conf/fbank.conf new file mode 100644 index 000000000..752323586 --- /dev/null +++ b/examples/asr_librispeech/conf/fbank.conf @@ -0,0 +1,2 @@ +--sample-frequency=16000 +--num-mel-bins=80 diff --git a/examples/asr_librispeech/conf/pitch.conf b/examples/asr_librispeech/conf/pitch.conf new file mode 100644 index 000000000..e959a19d5 --- /dev/null +++ b/examples/asr_librispeech/conf/pitch.conf @@ -0,0 +1 @@ +--sample-frequency=16000 diff --git a/examples/asr_librispeech/local/data_prep.sh b/examples/asr_librispeech/local/data_prep.sh new file mode 120000 index 000000000..3000aeaca --- /dev/null +++ b/examples/asr_librispeech/local/data_prep.sh @@ -0,0 +1 @@ +../../../speech_tools/kaldi/egs/librispeech/s5/local/data_prep.sh \ No newline at end of file diff --git a/examples/asr_librispeech/local/download_and_untar.sh b/examples/asr_librispeech/local/download_and_untar.sh new file mode 120000 index 000000000..4edb356c0 --- /dev/null +++ b/examples/asr_librispeech/local/download_and_untar.sh @@ -0,0 +1 @@ +../../../speech_tools/kaldi/egs/librispeech/s5/local/download_and_untar.sh \ No newline at end of file diff --git a/examples/asr_librispeech/local/score.sh b/examples/asr_librispeech/local/score.sh new file mode 120000 index 000000000..3a771d6c9 --- /dev/null +++ b/examples/asr_librispeech/local/score.sh @@ -0,0 +1 @@ +../../asr_wsj/local/score.sh \ No newline at end of file diff --git a/examples/asr_librispeech/path.sh b/examples/asr_librispeech/path.sh new file mode 100644 index 000000000..7a7e115b8 --- /dev/null +++ b/examples/asr_librispeech/path.sh @@ -0,0 +1,17 @@ +MAIN_ROOT=$PWD/../.. +KALDI_ROOT=$MAIN_ROOT/speech_tools/kaldi + +# BEGIN from kaldi path.sh +[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh +export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/tools/sctk/bin:$PWD:$PATH +[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 +. $KALDI_ROOT/tools/config/common_path.sh +export LC_ALL=C +# END + +export PATH=~/anaconda3/bin:$PATH +export PATH=$MAIN_ROOT:$MAIN_ROOT/speech_tools:$PATH +export PATH=$MAIN_ROOT/tools/sentencepiece/build/src:$PATH +export PYTHONPATH=$MAIN_ROOT:$MAIN_ROOT/speech_tools:$PYTHONPATH +export PYTHONUNBUFFERED=1 + diff --git a/examples/asr_librispeech/run.sh b/examples/asr_librispeech/run.sh new file mode 100755 index 000000000..f89dc0f7f --- /dev/null +++ b/examples/asr_librispeech/run.sh @@ -0,0 +1,254 @@ +#!/bin/bash + +# Copyright (c) 2019-present, Yiming Wang +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. 
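The scheduled-sampling options added to the two criteria above select a per-epoch probability of feeding the ground-truth previous token (rather than the model's own argmax prediction) once --start-scheduled-sampling-epoch is reached; later epochs reuse the last listed probability. A minimal sketch of that schedule, matching the indexing used in forward():

def truth_feeding_prob(probs, epoch, start_epoch):
    if epoch < start_epoch:
        return 1.0  # pure teacher forcing before scheduled sampling starts
    return probs[min(epoch - start_epoch, len(probs) - 1)]

assert truth_feeding_prob([0.4], epoch=11, start_epoch=11) == 0.4
assert truth_feeding_prob([0.8, 0.6, 0.4], epoch=20, start_epoch=11) == 0.4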
+ +set -e -o pipefail + +stage=0 +ngpus=1 # num GPUs for multiple GPUs training within a single node; should match those in $free_gpu +free_gpu= # comma-separated available GPU ids, eg., "0" or "0,1"; automatically assigned if on CLSP grid + +# E2E model related +affix= +train_set=train_960 +valid_set=dev +test_set="test_clean test_other dev_clean dev_other" +checkpoint=checkpoint_best.pt + +# LM related +lm_affix= +lm_checkpoint=checkpoint_best.pt +lm_shallow_fusion=true # no LM fusion if false +sentencepiece_vocabsize=5000 +sentencepiece_type=unigram + +# data related +dumpdir=data/dump # directory to dump full features +data= # path to where you want to put the downloaded data; need to be specified if not on CLSP grid +if [[ $(hostname -f) == *.clsp.jhu.edu ]]; then + data=/export/a15/vpanayotov/data +fi +data_url=www.openslr.org/resources/12 +kaldi_scoring=true + +# feature configuration +do_delta=false + + +. ./path.sh +. ./cmd.sh +. ./utils/parse_options.sh + +lmdir=exp/lm_lstm${lm_affix:+_${lm_affix}} +dir=exp/lstm${affix:+_$affix} + +if [ ${stage} -le 0 ]; then + echo "Stage 0: Data Downloading" + for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do + local/download_and_untar.sh $data $data_url $part + done +fi + +if [ ${stage} -le 1 ]; then + echo "Stage 0: Data Preparation" + for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do + # use underscore-separated names in data directories. + local/data_prep.sh $data/LibriSpeech/$part data/$(echo $part | sed s/-/_/g) + done +fi + +train_feat_dir=${dumpdir}/${train_set}/delta${do_delta}; mkdir -p ${train_feat_dir} +valid_feat_dir=${dumpdir}/${valid_set}/delta${do_delta}; mkdir -p ${valid_feat_dir} +if [ ${stage} -le 1 ]; then + echo "Stage 1: Feature Generation" + fbankdir=fbank + # Generate the fbank features; by default 80-dimensional fbanks with pitch on each frame + for dataset in dev_clean test_clean dev_other test_other train_clean_100 train_clean_360 train_other_500; do + steps/make_fbank_pitch.sh --cmd "$train_cmd" --nj 32 --write_utt2num_frames true \ + data/$dataset exp/make_fbank/$dataset ${fbankdir} + utils/fix_data_dir.sh data/$dataset + done + + utils/combine_data.sh --extra-files utt2num_frames data/${train_set} data/train_clean_100 data/train_clean_360 data/train_other_500 + utils/combine_data.sh --extra-files utt2num_frames data/${valid_set} data/dev_clean data/dev_other + + # compute global CMVN + compute-cmvn-stats scp:data/${train_set}/feats.scp data/${train_set}/cmvn.ark + + # dump features for training + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d ${train_feat_dir}/storage ]; then + utils/create_split_dir.pl \ + /export/b{10,11,12,13}/${USER}/fairseq-data/egs/asr_librispeech/dump/${train_set}/delta${do_delta}/storage \ + ${train_feat_dir}/storage + fi + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d ${valid_feat_dir}/storage ]; then + utils/create_split_dir.pl \ + /export/b{10,11,12,13}/${USER}/fairseq-data/egs/asr_librispeech/dump/${valid_set}/delta${do_delta}/storage \ + ${valid_feat_dir}/storage + fi + dump.sh --cmd "$train_cmd" --nj 80 --do_delta $do_delta \ + data/${train_set}/feats.scp data/${train_set}/cmvn.ark exp/dump_feats/train ${train_feat_dir} + dump.sh --cmd "$train_cmd" --nj 32 --do_delta $do_delta \ + data/${valid_set}/feats.scp data/${train_set}/cmvn.ark exp/dump_feats/valid ${valid_feat_dir} + for dataset in $test_set; do + test_feat_dir=${dumpdir}/$dataset/delta${do_delta}; mkdir -p ${test_feat_dir} + dump.sh --cmd "$train_cmd" --nj 32 --do_delta $do_delta \ + data/$dataset/feats.scp data/${train_set}/cmvn.ark exp/dump_feats/$dataset ${test_feat_dir} + done +fi + +dict=data/lang/${train_set}_${sentencepiece_type}${sentencepiece_vocabsize}_units.txt +sentencepiece_model=data/lang/${train_set}_${sentencepiece_type}${sentencepiece_vocabsize} +lmdatadir=data/lm_text +if [ ${stage} -le 2 ]; then + echo "Stage 2: Dictionary Preparation and Text Tokenization" + mkdir -p data/lang + cut -f 2- -d" " data/${train_set}/text > data/lang/input + echo "$0: training sentencepiece model..." + spm_train --bos_id=-1 --pad_id=0 --eos_id=1 --unk_id=2 --input=data/lang/input \ + --vocab_size=$sentencepiece_vocabsize --model_type=$sentencepiece_type \ + --model_prefix=$sentencepiece_model --input_sentence_size=10000000 + echo "$0: making a dictionary and tokenizing text for train/valid/test set..." + for dataset in $train_set $valid_set $test_set; do + text=data/$dataset/text + token_text=data/$dataset/token_text + spm_encode --model=${sentencepiece_model}.model --output_format=piece \ + <(cut -f 2- -d" " $text) | paste -d" " <(cut -f 1 -d" " $text) - > $token_text + if [ "$dataset" == "$train_set" ]; then + cat $token_text | tr ' ' '\n' | sort | uniq -c | awk '{print $2,$1}' | sort > $dict + wc -l $dict + fi + done + + echo "$0: preparing text for subword LM..." + mkdir -p $lmdatadir + for dataset in $train_set $valid_set $test_set; do + token_text=data/$dataset/token_text + cut -f 2- -d" " $token_text > $lmdatadir/$dataset.tokens + done + if [ ! 
-e $lmdatadir/librispeech-lm-norm.txt.gz ]; then + wget http://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz -P $lmdatadir + fi + zcat $lmdatadir/librispeech-lm-norm.txt.gz | \ + spm_encode --model=${sentencepiece_model}.model --output_format=piece | \ + cat $lmdatadir/$train_set.tokens - > $lmdatadir/train.tokens +fi + +lmdict=$dict +if [ ${stage} -le 3 ]; then + echo "Stage 3: Text Binarization for subword LM Training" + mkdir -p $lmdatadir/logs + for dataset in $test_set; do test_paths="$test_paths,$lmdatadir/$dataset.tokens"; done + test_paths=$lmdatadir/$(echo $test_set | sed 's/ /,/g').tokens + ${decode_cmd} $lmdatadir/logs/preprocess.log \ + python3 ../../preprocess.py --task language_modeling_for_asr \ + --workers 50 --srcdict $lmdict --only-source \ + --trainpref $lmdatadir/train.tokens \ + --validpref $lmdatadir/$valid_set.tokens \ + --testpref $test_paths \ + --destdir $lmdatadir +fi + +[ -z "$free_gpu" ] && [[ $(hostname -f) == *.clsp.jhu.edu ]] && free_gpu=$(free-gpu -n $ngpus) +[ -z "$free_gpu" ] && echo "$0: please specify --free-gpu" && exit 1; +[ $(echo $free_gpu | sed 's/,/ /g' | awk '{print NF}') -ne "$ngpus" ] && \ + echo "number of GPU ids in --free-gpu=$free_gpu does not match --ngpus=$ngpus" && exit 1; + +if [ ${stage} -le 4 ]; then + echo "Stage 4: subword LM Training" + valid_subset=valid + mkdir -p $lmdir/logs + log_file=$lmdir/logs/train.log + [ -f $lmdir/checkpoint_last.pt ] && log_file="-a $log_file" + CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../train.py $lmdatadir --seed 1 \ + --task language_modeling_for_asr --dict $lmdict \ + --log-interval 2000 --log-format simple \ + --num-workers 0 --max-tokens 20480 --max-sentences 256 \ + --valid-subset $valid_subset --max-sentences-valid 512 \ + --distributed-world-size $ngpus --distributed-rank 0 --distributed-port 100 \ + --max-epoch 25 --optimizer adam --lr 0.001 --weight-decay 5e-06 \ + --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \ + --save-dir $lmdir --restore-file checkpoint_last.pt --save-interval-updates 2000 \ + --keep-interval-updates 5 --keep-last-epochs 5 --validate-interval 1 \ + --arch lstm_lm_librispeech --criterion cross_entropy --sample-break-mode eos 2>&1 | tee $log_file +fi + +if [ ${stage} -le 5 ]; then + echo "Stage 5: subword LM Evaluation" + num=$(echo $test_set | awk '{print NF-1}') + gen_set=test + for i in $(seq $num); do gen_set="$gen_set test$i"; done + for gen_subset in $gen_set; do + log_file=$lmdir/logs/evaluation_$gen_subset.log + python3 ../../eval_lm.py $lmdatadir --cpu \ + --task language_modeling_for_asr --dict $lmdict --gen-subset $gen_subset \ + --max-tokens 40960 --max-sentences 512 --sample-break-mode eos \ + --path $lmdir/$lm_checkpoint 2>&1 | tee $log_file + done +fi + +train_feat=$train_feat_dir/feats.scp +train_token_text=data/$train_set/token_text +valid_feat=$valid_feat_dir/feats.scp +valid_token_text=data/$valid_set/token_text +if [ ${stage} -le 6 ]; then + echo "Stage 6: Model Training" + opts="" + valid_subset=valid + [ -f local/wer_output_filter ] && opts="$opts --wer-output-filter local/wer_output_filter" + mkdir -p $dir/logs + log_file=$dir/logs/train.log + [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" + CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --seed 1 \ + --log-interval 1000 --log-format simple --print-training-sample-interval 1000 \ + --num-workers 0 --max-tokens 24000 --max-sentences 32 \ + --valid-subset $valid_subset --max-sentences-valid 64 \ + --distributed-world-size $ngpus --distributed-rank 0 --distributed-port 100 \ + 
--max-epoch 30 --optimizer adam --lr 0.001 --weight-decay 0.0 \ + --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --min-lr 1e-5 --start-reduce-lr-epoch 11 \ + --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates 400 \ + --keep-interval-updates 5 --keep-last-epochs 5 --validate-interval 1 \ + --arch speech_conv_lstm_librispeech --criterion label_smoothed_cross_entropy_with_wer \ + --label-smoothing 0.05 --smoothing-type temporal \ + --scheduled-sampling-probs 1.0 --start-scheduled-sampling-epoch 1 \ + --train-feat-files $train_feat --train-text-files $train_token_text \ + --valid-feat-files $valid_feat --valid-text-files $valid_token_text \ + --dict $dict \ + --max-source-positions 9999 --max-target-positions 999 $opts 2>&1 | tee $log_file +fi + +if [ ${stage} -le 7 ]; then + echo "Stage 7: Decoding" + opts="" + path=$dir/$checkpoint + decode_affix= + if $lm_shallow_fusion; then + path="$path:$lmdir/$lm_checkpoint" + opts="$opts --lm-weight 0.7 --coverage-weight 0.01" + decode_affix=shallow_fusion + fi + [ -f local/wer_output_filter ] && opts="$opts --wer-output-filter local/wer_output_filter" + for dataset in $test_set; do + feat=${dumpdir}/$dataset/delta${do_delta}/feats.scp + text=data/$dataset/token_text + CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py \ + --max-tokens 20000 --max-sentences 32 --num-shards 1 --shard-id 0 \ + --test-feat-files $feat --test-text-files $text \ + --dict $dict --remove-bpe sentencepiece \ + --max-source-positions 9999 --max-target-positions 999 \ + --path $path --beam 50 --max-len-a 0.1 --max-len-b 0 --lenpen 1.0 \ + --results-path $dir/decode_$dataset${decode_affix:+_${decode_affix}} $opts \ + --print-alignment 2>&1 | tee $dir/logs/decode_$dataset${decode_affix:+_${decode_affix}}.log + + if $kaldi_scoring; then + echo "verify WER by scoring with Kaldi..." 
+ local/score.sh data/$dataset $dir/decode_$dataset${decode_affix:+_${decode_affix}} + cat $dir/decode_$dataset${decode_affix:+_${decode_affix}}/scoring_kaldi/wer + fi + done +fi diff --git a/examples/asr_librispeech/steps b/examples/asr_librispeech/steps new file mode 120000 index 000000000..ec9b528ac --- /dev/null +++ b/examples/asr_librispeech/steps @@ -0,0 +1 @@ +../../speech_tools/kaldi/egs/wsj/s5/steps \ No newline at end of file diff --git a/examples/asr_librispeech/utils b/examples/asr_librispeech/utils new file mode 120000 index 000000000..ea44d93b9 --- /dev/null +++ b/examples/asr_librispeech/utils @@ -0,0 +1 @@ +../../speech_tools/kaldi/egs/wsj/s5/utils \ No newline at end of file diff --git a/fairseq/criterions/cross_entropy_with_wer.py b/fairseq/criterions/cross_entropy_with_wer.py index 63dbb3e23..f44e24ec8 100644 --- a/fairseq/criterions/cross_entropy_with_wer.py +++ b/fairseq/criterions/cross_entropy_with_wer.py @@ -41,7 +41,7 @@ def add_args(parser): help='print a training sample (reference + ' 'prediction) every this number of updates') parser.add_argument('--scheduled-sampling-probs', type=eval_str_list, - metavar='P_1,P_2,...,P_N', default=1.0, + metavar='P_1,P_2,...,P_N', default=[1.0], help='schedule sampling probabilities of sampling the truth ' 'labels for N epochs starting from --start-schedule-sampling-epoch; ' 'all later epochs using P_N') @@ -170,9 +170,12 @@ def forward(self, model, sample, reduce=True): id = sample['id'].data[i].item() length = utils.strip_pad(target.data[i], self.padding_idx).size(0) #ref_one = dict.tokens_to_sentence(dict.string(target.data[i])) - ref_one = self.train_tgt_dataset.get_original_text(id, dict) + ref_one = self.train_tgt_dataset.get_original_text(id, dict, + bpe_symbol=self.args.remove_bpe) pred_one = dict.tokens_to_sentence( - dict.string(pred.data[i][:length])) + dict.string(pred.data[i][:length]), + bpe_symbol=self.args.remove_bpe, + ) print('| sample REF: ' + ref_one) print('| sample PRD: ' + pred_one) # word error stats code ends diff --git a/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py b/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py index 0f4b1494c..5a26808b1 100644 --- a/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py +++ b/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py @@ -49,7 +49,7 @@ def add_args(parser): choices=['uniform', 'unigram', 'temporal'], help='label smoothing type. 
Default: uniform') parser.add_argument('--scheduled-sampling-probs', type=eval_str_list, - metavar='P_1,P_2,...,P_N', default=1.0, + metavar='P_1,P_2,...,P_N', default=[1.0], help='scheduled sampling probabilities of sampling the truth ' 'labels for N epochs starting from --start-schedule-sampling-epoch; ' 'all later epochs using P_N') @@ -178,9 +178,12 @@ def forward(self, model, sample, reduce=True): id = sample['id'].data[i].item() length = utils.strip_pad(target.data[i], self.padding_idx).size(0) #ref_one = dict.tokens_to_sentence(dict.string(target.data[i])) - ref_one = self.train_tgt_dataset.get_original_text(id, dict) + ref_one = self.train_tgt_dataset.get_original_text(id, dict, + bpe_symbol=self.args.remove_bpe) pred_one = dict.tokens_to_sentence( - dict.string(pred.data[i][:length])) + dict.string(pred.data[i][:length]), + bpe_symbol=self.args.remove_bpe, + ) print('| sample REF: ' + ref_one) print('| sample PRD: ' + pred_one) # word error stats code ends diff --git a/fairseq/data/scp_dataset.py b/fairseq/data/scp_dataset.py index c8e4d3ee6..a6fc3e444 100644 --- a/fairseq/data/scp_dataset.py +++ b/fairseq/data/scp_dataset.py @@ -223,10 +223,10 @@ def get_original_tokens(self, i): self.check_index(i) return self.tokens_list[i] - def get_original_text(self, i, dictionary): + def get_original_text(self, i, dictionary, bpe_symbol=None): self.check_index(i) return dictionary.tokens_to_sentence(self.tokens_list[i], - use_unk_sym=False) + use_unk_sym=False, bpe_symbol=bpe_symbol) def __len__(self): return self.size diff --git a/fairseq/data/token_dictionary.py b/fairseq/data/token_dictionary.py index ba409180b..3f2437c50 100644 --- a/fairseq/data/token_dictionary.py +++ b/fairseq/data/token_dictionary.py @@ -98,7 +98,10 @@ def dummy_sentence(self, length): t[-1] = self.eos() return t - def tokens_to_sentence(self, line, line_tokenizer=tokenize_line, use_unk_sym=True): + def tokens_to_sentence(self, line, line_tokenizer=tokenize_line, + use_unk_sym=True, bpe_symbol=None): + if bpe_symbol is not None: + return data_utils.process_bpe_symbol(sent, bpe_symbol) # use_unk_sym=False when we want to restore original transcripts from # token sequences, e.g., obtain reference to compute WER tokens = line_tokenizer(line) diff --git a/fairseq/models/speech_lstm.py b/fairseq/models/speech_lstm.py index ac2aba446..acb9cfd81 100644 --- a/fairseq/models/speech_lstm.py +++ b/fairseq/models/speech_lstm.py @@ -773,6 +773,17 @@ def lstm_lm_wsj(args): base_lm_architecture(args) +@register_model_architecture('lstm_lm', 'lstm_lm_librispeech') +def lstm_lm_librispeech(args): + args.dropout = getattr(args, 'dropout', 0.2) + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024) + args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 1024) + args.decoder_layers = getattr(args, 'decoder_layers', 1) + args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 1024) + args.share_embed = getattr(args, 'share_embed', True) + base_lm_architecture(args) + + @register_model_architecture('lstm_lm', 'lstm_wordlm_wsj') def lstm_wordlm_wsj(args): args.dropout = getattr(args, 'dropout', 0.3) @@ -816,6 +827,21 @@ def base_architecture(args): args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False) args.pretrained_lm_checkpoint = getattr(args, 'pretrained_lm_checkpoint', None) + @register_model_architecture('speech_lstm', 'speech_conv_lstm_wsj') def conv_lstm_wsj(args): base_architecture(args) + + +@register_model_architecture('speech_lstm', 
'speech_conv_lstm_librispeech') +def speech_conv_lstm_librispeech(args): + args.dropout = getattr(args, 'dropout', 0.3) + args.encoder_rnn_hidden_size = getattr(args, 'encoder_rnn_hidden_size', 1024) + args.encoder_rnn_layers = getattr(args, 'encoder_rnn_layers', 3) + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024) + args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 1024) + args.decoder_layers = getattr(args, 'decoder_layers', 3) + args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 3072) + args.attention_type = getattr(args, 'attention_type', 'bahdanau') + args.attention_dim = getattr(args, 'attention_dim', 1024) + base_architecture(args) diff --git a/fairseq/wer.py b/fairseq/wer.py index abc60973c..e3a89a1d2 100644 --- a/fairseq/wer.py +++ b/fairseq/wer.py @@ -46,7 +46,7 @@ def parse_wer_output_filter(self, wer_output_filter): print('Unsupported pattern: "{}", ignored'.format(line), file=sys.stderr) - def add_prediction(self, utt_id, pred): + def add_prediction(self, utt_id, pred, bpe_symbol=None): if not isinstance(utt_id, str): raise TypeError('utt_id must be a string(got {})'.format(type(utt_id))) if not isinstance(pred, str): @@ -56,12 +56,12 @@ def add_prediction(self, utt_id, pred): 'Duplicated utterance id detected: {}'.format(utt_id) self.char_results[utt_id] = pred + '\n' - pred_words = self.dict.tokens_to_sentence(pred) + pred_words = self.dict.tokens_to_sentence(pred, bpe_symbol=bpe_symbol) assert not utt_id in self.results, \ 'Duplicated utterance id detected: {}'.format(utt_id) self.results[utt_id] = pred_words + '\n' - def add_evaluation(self, utt_id, ref, pred): + def add_evaluation(self, utt_id, ref, pred, bpe_symbol=None): if not isinstance(utt_id, str): raise TypeError('utt_id must be a string(got {})'.format(type(utt_id))) if not isinstance(ref, str): @@ -83,8 +83,8 @@ def add_evaluation(self, utt_id, ref, pred): self.char_counter += counter # word level counts - ref_words = self.dict.tokens_to_sentence(ref, use_unk_sym=False) - pred_words = self.dict.tokens_to_sentence(pred) + ref_words = self.dict.tokens_to_sentence(ref, use_unk_sym=False, bpe_symbol=bpe_symbol) + pred_words = self.dict.tokens_to_sentence(pred, bpe_symbol=bpe_symbol) # filter words according to self.word_filters (support re.sub only) for pattern, repl in self.word_filters: diff --git a/speech_recognize.py b/speech_recognize.py index b617b6050..3614cba6a 100755 --- a/speech_recognize.py +++ b/speech_recognize.py @@ -135,14 +135,14 @@ def main(args): target_str = task.dataset(args.gen_subset).tgt.get_original_tokens(sample_id) if not args.quiet: target_sent = dict.tokens_to_sentence(target_str, - use_unk_sym=False) + use_unk_sym=False, bpe_symbol=args.remove_bpe) print('T-{}\t{}'.format(utt_id, target_sent)) # Process top predictions for i, hypo in enumerate(hypos[i][:min(len(hypos), args.nbest)]): - hypo_str = dict.string(hypo['tokens'].int().cpu(), args.remove_bpe) + hypo_str = dict.string(hypo['tokens'].int().cpu()) # not removing bpe at this point if not args.quiet or i == 0: - hypo_sent = dict.tokens_to_sentence(hypo_str) + hypo_sent = dict.tokens_to_sentence(hypo_str, bpe_symbol=args.remove_bpe) if not args.quiet: print('H-{}\t{}\t{}'.format(utt_id, hypo_sent, hypo['score'])) @@ -156,9 +156,9 @@ def main(args): save_dir = os.path.join(args.results_path, 'attn_plots') os.makedirs(save_dir, exist_ok=True) plot_attention(attention, hypo_sent, utt_id, save_dir) - scorer.add_prediction(utt_id, hypo_str) + scorer.add_prediction(utt_id, hypo_str, 
bpe_symbol=args.remove_bpe) if has_target: - scorer.add_evaluation(utt_id, target_str, hypo_str) + scorer.add_evaluation(utt_id, target_str, hypo_str, bpe_symbol=args.remove_bpe) wps_meter.update(num_generated_tokens) t.log({'wps': round(wps_meter.avg)}) diff --git a/speech_tools/.gitignore b/speech_tools/.gitignore index 77acdc259..b86428794 100644 --- a/speech_tools/.gitignore +++ b/speech_tools/.gitignore @@ -1,3 +1,4 @@ kaldi kaldi-io-for-python kaldi_io.py +sentencepiece diff --git a/speech_tools/Makefile b/speech_tools/Makefile index 5fe67d1c0..28b5c2689 100644 --- a/speech_tools/Makefile +++ b/speech_tools/Makefile @@ -2,21 +2,27 @@ KALDI = .PHONY: all clean -all: kaldi kaldi-io-for-python +all: kaldi kaldi-io-for-python sentencepiece kaldi-io-for-python: + rm -rf kaldi-io-for-python git clone https://github.com/vesis84/kaldi-io-for-python.git ln -sf kaldi-io-for-python/kaldi_io/kaldi_io.py kaldi_io.py +sentencepiece: + rm -rf sentencepiece + git clone https://github.com/google/sentencepiece.git + cd sentencepiece && mkdir build && cd build && (cmake3 .. || cmake ..) && $(MAKE) + ifneq ($(strip $(KALDI)),) kaldi: ln -s $(KALDI) kaldi else kaldi: - # git clone https://github.com/kaldi-asr/kaldi.git kaldi_github; cd kaldi_github/tools; $(MAKE) all - # cd kaldi_github/src; ./configure --shared --use-cuda=no; $(MAKE) depend; $(MAKE) all - # ln -nfs kaldi_github kaldi + test -d kaldi || git clone https://github.com/kaldi-asr/kaldi.git + cd kaldi/tools; $(MAKE) all + cd kaldi/src; ./configure --shared --use-cuda=no; $(MAKE) depend; $(MAKE) all endif clean: - rm -fr kaldi kaldi-io-for-python kaldi_io.py + rm -rf kaldi kaldi-io-for-python kaldi_io.py sentencepiece diff --git a/speech_train.py b/speech_train.py index f51947a23..83bcd121a 100755 --- a/speech_train.py +++ b/speech_train.py @@ -299,6 +299,10 @@ def print_options_meaning_changes(args): def cli_main(): parser = options.get_training_parser(default_task='speech_recognition') + parser.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None, + help='remove BPE tokens before scoring ' + '(can be set to sentencepiece). 
Being used for monitoring ' + 'and validation') args = options.parse_args_and_arch(parser) print_options_meaning_changes(args) From 3ea6bb9023d37ea3bd611c8b8e6cd55b5e660415 Mon Sep 17 00:00:00 2001 From: freewym Date: Mon, 27 May 2019 01:35:47 -0400 Subject: [PATCH 021/119] pure batch decoding with LookAhead Word LM --- fairseq/models/external_language_model.py | 44 ++++++++++++++++------- 1 file changed, 32 insertions(+), 12 deletions(-) diff --git a/fairseq/models/external_language_model.py b/fairseq/models/external_language_model.py index 4c38e36ed..cd807b21d 100644 --- a/fairseq/models/external_language_model.py +++ b/fairseq/models/external_language_model.py @@ -59,11 +59,13 @@ def __init__(self, wordlm, subword_dict, oov_penalty=1e-4, open_vocab=True): word_dict = self.lm_decoder.dictionary assert isinstance(word_dict, TokenDictionary) + self.word_pad_idx = word_dict.pad() self.word_eos_idx = word_dict.eos() self.word_unk_idx = word_dict.unk() assert isinstance(subword_dict, TokenDictionary) self.subword_space_idx = subword_dict.space() + self.subword_pad_idx = subword_dict.pad() self.subword_eos_idx = subword_dict.eos() self.subword_vocab_size = len(subword_dict) @@ -71,6 +73,17 @@ def __init__(self, wordlm, subword_dict, oov_penalty=1e-4, open_vocab=True): x, non_lang_syms=subword_dict.non_lang_syms).split(' ') self.lexroot = lexical_prefix_tree(word_dict, subword_dict, tokenizer) + def max_out_degree(node): + if len(node.children) == 0: + return 0 + cur_max = len(node.children) + for _, node in node.children.items(): + cur_max = max(cur_max, max_out_degree(node)) + return cur_max + + self.max_num_children = max_out_degree(self.lexroot) + assert self.max_num_children <= self.subword_vocab_size + @torch.no_grad() def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): assert incremental_state is not None, \ @@ -172,21 +185,28 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): ).squeeze(-1) # compute transition probabilities to child nodes (case 2 in Eqn. 
15) + subword_idx = [[self.subword_pad_idx] * self.max_num_children \ + for _ in range(bsz)] + left_ranges = [[self.word_pad_idx] * self.max_num_children \ + for _ in range(bsz)] + right_ranges = [[self.word_pad_idx] * self.max_num_children \ + for _ in range(bsz)] for i in range(bsz): node = nodes[i] if node is not None and len(node.children) > 0: - subword_idx, left_ranges, right_ranges = [], [], [] - for sidx, child in node.children.items(): - subword_idx.append(sidx) - left_ranges.append(child.word_set[0]) - right_ranges.append(child.word_set[1]) - subword_idx = prev_output_tokens.new(subword_idx) - left_ranges = prev_output_tokens.new(left_ranges) - right_ranges = prev_output_tokens.new(right_ranges) - out_probs[i, :, subword_idx] = \ - self.zero if sum_probs[i].item() < self.zero else \ - (cumsum_probs[i, :, right_ranges] - \ - cumsum_probs[i, :, left_ranges]) / sum_probs[i] + for j, (sidx, child) in enumerate(node.children.items()): + subword_idx[i][j] = sidx + left_ranges[i][j] = child.word_set[0] + right_ranges[i][j] = child.word_set[1] + # B x 1 x max_num_children + subword_idx = prev_output_tokens.new(subword_idx).unsqueeze(1) + left_ranges = prev_output_tokens.new(left_ranges).unsqueeze(1) + right_ranges = prev_output_tokens.new(right_ranges).unsqueeze(1) + cumsum_probs_children = (cumsum_probs.gather(-1, right_ranges) - \ + cumsum_probs.gather(-1, left_ranges)) / sum_probs.unsqueeze(-1) + cumsum_probs_children[sum_probs.squeeze(-1) < self.zero, :, :] = self.zero + out_probs.scatter_(-1, subword_idx, cumsum_probs_children) + out_probs[:, :, self.subword_pad_idx] = self.zero # apply word-level probabilies for and (case 1 in Eqn. 15) word_idx, batch_node_word_end_mask = [], [] From 36baed5c0cfbfdfd6b3add78ccddc806a443af6d Mon Sep 17 00:00:00 2001 From: freewym Date: Fri, 31 May 2019 02:27:57 -0400 Subject: [PATCH 022/119] librispeech recipe fix; code adaptation/changes according to the commits on June 11, 2019 --- examples/asr_librispeech/path.sh | 2 +- examples/asr_librispeech/run.sh | 106 +++++++++--------- examples/asr_wsj/run.sh | 7 +- fairseq/criterions/cross_entropy_with_wer.py | 20 +++- .../label_smoothed_cross_entropy_with_wer.py | 7 +- fairseq/data/token_dictionary.py | 2 +- fairseq/models/speech_lstm.py | 56 ++++----- .../lr_scheduler/reduce_lr_on_plateau_v2.py | 27 ++++- fairseq/tasks/speech_recognition.py | 32 +++--- speech_recognize.py | 4 +- speech_tools/Makefile | 2 +- speech_train.py | 6 +- 12 files changed, 148 insertions(+), 123 deletions(-) diff --git a/examples/asr_librispeech/path.sh b/examples/asr_librispeech/path.sh index 7a7e115b8..3290c7576 100644 --- a/examples/asr_librispeech/path.sh +++ b/examples/asr_librispeech/path.sh @@ -11,7 +11,7 @@ export LC_ALL=C export PATH=~/anaconda3/bin:$PATH export PATH=$MAIN_ROOT:$MAIN_ROOT/speech_tools:$PATH -export PATH=$MAIN_ROOT/tools/sentencepiece/build/src:$PATH +export PATH=$MAIN_ROOT/speech_tools/sentencepiece/build/src:$PATH export PYTHONPATH=$MAIN_ROOT:$MAIN_ROOT/speech_tools:$PYTHONPATH export PYTHONUNBUFFERED=1 diff --git a/examples/asr_librispeech/run.sh b/examples/asr_librispeech/run.sh index f89dc0f7f..6ae6fa6c8 100755 --- a/examples/asr_librispeech/run.sh +++ b/examples/asr_librispeech/run.sh @@ -55,7 +55,7 @@ if [ ${stage} -le 0 ]; then fi if [ ${stage} -le 1 ]; then - echo "Stage 0: Data Preparation" + echo "Stage 1: Data Preparation" for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do # use underscore-separated names in data directories. 
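    # e.g. the sed below maps "dev-clean" to data/dev_clean and
    # "train-clean-100" to data/train_clean_100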
local/data_prep.sh $data/LibriSpeech/$part data/$(echo $part | sed s/-/_/g) @@ -64,8 +64,8 @@ fi train_feat_dir=${dumpdir}/${train_set}/delta${do_delta}; mkdir -p ${train_feat_dir} valid_feat_dir=${dumpdir}/${valid_set}/delta${do_delta}; mkdir -p ${valid_feat_dir} -if [ ${stage} -le 1 ]; then - echo "Stage 1: Feature Generation" +if [ ${stage} -le 2 ]; then + echo "Stage 2: Feature Generation" fbankdir=fbank # Generate the fbank features; by default 80-dimensional fbanks with pitch on each frame for dataset in dev_clean test_clean dev_other test_other train_clean_100 train_clean_360 train_other_500; do @@ -83,12 +83,12 @@ if [ ${stage} -le 1 ]; then # dump features for training if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d ${train_feat_dir}/storage ]; then utils/create_split_dir.pl \ - /export/b{10,11,12,13}/${USER}/fairseq-data/egs/asr_librispeech/dump/${train_set}/delta${do_delta}/storage \ + /export/b1{4,5,6,7}/${USER}/fairseq-data/egs/asr_librispeech/dump/${train_set}/delta${do_delta}/storage \ ${train_feat_dir}/storage fi if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d ${valid_feat_dir}/storage ]; then utils/create_split_dir.pl \ - /export/b{10,11,12,13}/${USER}/fairseq-data/egs/asr_librispeech/dump/${valid_set}/delta${do_delta}/storage \ + /export/b1{4,5,6,7}/${USER}/fairseq-data/egs/asr_librispeech/dump/${valid_set}/delta${do_delta}/storage \ ${valid_feat_dir}/storage fi dump.sh --cmd "$train_cmd" --nj 80 --do_delta $do_delta \ @@ -105,14 +105,15 @@ fi dict=data/lang/${train_set}_${sentencepiece_type}${sentencepiece_vocabsize}_units.txt sentencepiece_model=data/lang/${train_set}_${sentencepiece_type}${sentencepiece_vocabsize} lmdatadir=data/lm_text -if [ ${stage} -le 2 ]; then - echo "Stage 2: Dictionary Preparation and Text Tokenization" +if [ ${stage} -le 3 ]; then + echo "Stage 3: Dictionary Preparation and Text Tokenization" mkdir -p data/lang cut -f 2- -d" " data/${train_set}/text > data/lang/input echo "$0: training sentencepiece model..." spm_train --bos_id=-1 --pad_id=0 --eos_id=1 --unk_id=2 --input=data/lang/input \ - --vocab_size=$sentencepiece_vocabsize --model_type=$sentencepiece_type \ - --model_prefix=$sentencepiece_model --input_sentence_size=10000000 + --vocab_size=$((sentencepiece_vocabsize+3)) --character_coverage=1.0 \ + --model_type=$sentencepiece_type --model_prefix=$sentencepiece_model \ + --input_sentence_size=10000000 echo "$0: making a dictionary and tokenizing text for train/valid/test set..." for dataset in $train_set $valid_set $test_set; do text=data/$dataset/text @@ -120,7 +121,8 @@ if [ ${stage} -le 2 ]; then spm_encode --model=${sentencepiece_model}.model --output_format=piece \ <(cut -f 2- -d" " $text) | paste -d" " <(cut -f 1 -d" " $text) - > $token_text if [ "$dataset" == "$train_set" ]; then - cat $token_text | tr ' ' '\n' | sort | uniq -c | awk '{print $2,$1}' | sort > $dict + cut -f 2- -d" " $token_text | tr ' ' '\n' | sort | uniq -c | \ + awk '{print $2,$1}' | sort > $dict wc -l $dict fi done @@ -134,17 +136,18 @@ if [ ${stage} -le 2 ]; then if [ ! -e $lmdatadir/librispeech-lm-norm.txt.gz ]; then wget http://www.openslr.org/resources/11/librispeech-lm-norm.txt.gz -P $lmdatadir fi + echo "$0: preparing extra corpus for subword LM training..." 
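  # The pipeline below decompresses the extra LibriSpeech LM corpus, encodes it
  # into subword pieces with the sentencepiece model trained above, and prepends
  # the tokenized acoustic training transcripts, producing
  # $lmdatadir/train.tokens for LM training.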
zcat $lmdatadir/librispeech-lm-norm.txt.gz | \ spm_encode --model=${sentencepiece_model}.model --output_format=piece | \ cat $lmdatadir/$train_set.tokens - > $lmdatadir/train.tokens fi lmdict=$dict -if [ ${stage} -le 3 ]; then - echo "Stage 3: Text Binarization for subword LM Training" +if [ ${stage} -le 4 ]; then + echo "Stage 4: Text Binarization for subword LM Training" mkdir -p $lmdatadir/logs - for dataset in $test_set; do test_paths="$test_paths,$lmdatadir/$dataset.tokens"; done - test_paths=$lmdatadir/$(echo $test_set | sed 's/ /,/g').tokens + for dataset in $test_set; do test_paths="$test_paths $lmdatadir/$dataset.tokens"; done + test_paths=$(echo $test_paths | awk '{$1=$1;print}' | tr ' ' ',') ${decode_cmd} $lmdatadir/logs/preprocess.log \ python3 ../../preprocess.py --task language_modeling_for_asr \ --workers 50 --srcdict $lmdict --only-source \ @@ -154,40 +157,42 @@ if [ ${stage} -le 3 ]; then --destdir $lmdatadir fi -[ -z "$free_gpu" ] && [[ $(hostname -f) == *.clsp.jhu.edu ]] && free_gpu=$(free-gpu -n $ngpus) +[ -z "$free_gpu" ] && [[ $(hostname -f) == *.clsp.jhu.edu ]] && free_gpu=$(free-gpu -n $ngpus) || \ + echo "Unable to get $ngpus GPUs" [ -z "$free_gpu" ] && echo "$0: please specify --free-gpu" && exit 1; [ $(echo $free_gpu | sed 's/,/ /g' | awk '{print NF}') -ne "$ngpus" ] && \ echo "number of GPU ids in --free-gpu=$free_gpu does not match --ngpus=$ngpus" && exit 1; -if [ ${stage} -le 4 ]; then - echo "Stage 4: subword LM Training" +if [ ${stage} -le 5 ]; then + echo "Stage 5: subword LM Training" valid_subset=valid mkdir -p $lmdir/logs log_file=$lmdir/logs/train.log [ -f $lmdir/checkpoint_last.pt ] && log_file="-a $log_file" CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../train.py $lmdatadir --seed 1 \ --task language_modeling_for_asr --dict $lmdict \ - --log-interval 2000 --log-format simple \ - --num-workers 0 --max-tokens 20480 --max-sentences 256 \ - --valid-subset $valid_subset --max-sentences-valid 512 \ + --log-interval 8000 --log-format simple \ + --num-workers 0 --max-tokens 30720 --max-sentences 1024 \ + --valid-subset $valid_subset --max-sentences-valid 1536 \ --distributed-world-size $ngpus --distributed-rank 0 --distributed-port 100 \ - --max-epoch 25 --optimizer adam --lr 0.001 --weight-decay 5e-06 \ + --max-epoch 25 --optimizer adam --lr 0.001 --clip-norm 1.0 \ --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \ - --save-dir $lmdir --restore-file checkpoint_last.pt --save-interval-updates 2000 \ - --keep-interval-updates 5 --keep-last-epochs 5 --validate-interval 1 \ + --save-dir $lmdir --restore-file checkpoint_last.pt --save-interval-updates 8000 \ + --keep-interval-updates 3 --keep-last-epochs 5 --validate-interval 1 \ --arch lstm_lm_librispeech --criterion cross_entropy --sample-break-mode eos 2>&1 | tee $log_file fi -if [ ${stage} -le 5 ]; then - echo "Stage 5: subword LM Evaluation" +if [ ${stage} -le 6 ]; then + echo "Stage 6: subword LM Evaluation" + gen_set_array=(test) num=$(echo $test_set | awk '{print NF-1}') - gen_set=test - for i in $(seq $num); do gen_set="$gen_set test$i"; done - for gen_subset in $gen_set; do - log_file=$lmdir/logs/evaluation_$gen_subset.log + for i in $(seq $num); do gen_set_array[$i]="test$i"; done + test_set_array=($test_set) + for i in $(seq 0 $num); do + log_file=$lmdir/logs/evaluation_${test_set_array[$i]}.log python3 ../../eval_lm.py $lmdatadir --cpu \ - --task language_modeling_for_asr --dict $lmdict --gen-subset $gen_subset \ - --max-tokens 40960 --max-sentences 512 --sample-break-mode eos \ + --task 
language_modeling_for_asr --dict $lmdict --gen-subset ${gen_set_array[$i]} \ + --max-tokens 40960 --max-sentences 1536 --sample-break-mode eos \ --path $lmdir/$lm_checkpoint 2>&1 | tee $log_file done fi @@ -196,54 +201,51 @@ train_feat=$train_feat_dir/feats.scp train_token_text=data/$train_set/token_text valid_feat=$valid_feat_dir/feats.scp valid_token_text=data/$valid_set/token_text -if [ ${stage} -le 6 ]; then - echo "Stage 6: Model Training" - opts="" +if [ ${stage} -le 7 ]; then + echo "Stage 7: Model Training" valid_subset=valid - [ -f local/wer_output_filter ] && opts="$opts --wer-output-filter local/wer_output_filter" mkdir -p $dir/logs log_file=$dir/logs/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --seed 1 \ - --log-interval 1000 --log-format simple --print-training-sample-interval 1000 \ - --num-workers 0 --max-tokens 24000 --max-sentences 32 \ - --valid-subset $valid_subset --max-sentences-valid 64 \ + --log-interval 4000 --log-format simple --print-training-sample-interval 2000 \ + --num-workers 0 --max-tokens 26000 --max-sentences 24 \ + --valid-subset $valid_subset --max-sentences-valid 48 \ --distributed-world-size $ngpus --distributed-rank 0 --distributed-port 100 \ - --max-epoch 30 --optimizer adam --lr 0.001 --weight-decay 0.0 \ - --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --min-lr 1e-5 --start-reduce-lr-epoch 11 \ - --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates 400 \ - --keep-interval-updates 5 --keep-last-epochs 5 --validate-interval 1 \ + --max-epoch 25 --optimizer adam --lr 0.001 --weight-decay 0.0 --clip-norm 2.0 \ + --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --min-lr 1e-5 --start-reduce-lr-epoch 10 \ + --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates 3000 \ + --keep-interval-updates 3 --keep-last-epochs 5 --validate-interval 1 \ --arch speech_conv_lstm_librispeech --criterion label_smoothed_cross_entropy_with_wer \ - --label-smoothing 0.05 --smoothing-type temporal \ + --label-smoothing 0.1 --smoothing-type uniform \ --scheduled-sampling-probs 1.0 --start-scheduled-sampling-epoch 1 \ --train-feat-files $train_feat --train-text-files $train_token_text \ --valid-feat-files $valid_feat --valid-text-files $valid_token_text \ - --dict $dict \ - --max-source-positions 9999 --max-target-positions 999 $opts 2>&1 | tee $log_file + --dict $dict --remove-bpe sentencepiece \ + --max-source-positions 9999 --max-target-positions 999 2>&1 | tee $log_file fi -if [ ${stage} -le 7 ]; then - echo "Stage 7: Decoding" +if [ ${stage} -le 8 ]; then + echo "Stage 8: Decoding" opts="" path=$dir/$checkpoint decode_affix= if $lm_shallow_fusion; then path="$path:$lmdir/$lm_checkpoint" - opts="$opts --lm-weight 0.7 --coverage-weight 0.01" + opts="$opts --lm-weight 0.35 --coverage-weight 0.0" decode_affix=shallow_fusion fi - [ -f local/wer_output_filter ] && opts="$opts --wer-output-filter local/wer_output_filter" for dataset in $test_set; do feat=${dumpdir}/$dataset/delta${do_delta}/feats.scp text=data/$dataset/token_text CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py \ - --max-tokens 20000 --max-sentences 32 --num-shards 1 --shard-id 0 \ + --max-tokens 16000 --max-sentences 24 --num-shards 1 --shard-id 0 \ --test-feat-files $feat --test-text-files $text \ --dict $dict --remove-bpe sentencepiece \ --max-source-positions 9999 --max-target-positions 999 \ - --path $path --beam 50 --max-len-a 0.1 
--max-len-b 0 --lenpen 1.0 \ + --path $path --beam 25 --max-len-a 0.08 --max-len-b 0 --lenpen 1.0 \ --results-path $dir/decode_$dataset${decode_affix:+_${decode_affix}} $opts \ - --print-alignment 2>&1 | tee $dir/logs/decode_$dataset${decode_affix:+_${decode_affix}}.log + 2>&1 | tee $dir/logs/decode_$dataset${decode_affix:+_${decode_affix}}.log if $kaldi_scoring; then echo "verify WER by scoring with Kaldi..." diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index 63a38537f..3f07108d7 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -183,7 +183,8 @@ if [ ${stage} -le 3 ]; then fi fi -[ -z "$free_gpu" ] && [[ $(hostname -f) == *.clsp.jhu.edu ]] && free_gpu=$(free-gpu -n $ngpus) +[ -z "$free_gpu" ] && [[ $(hostname -f) == *.clsp.jhu.edu ]] && free_gpu=$(free-gpu -n $ngpus) || \ + echo "Unable to get $ngpus GPUs" [ -z "$free_gpu" ] && echo "$0: please specify --free-gpu" && exit 1; [ $(echo $free_gpu | sed 's/,/ /g' | awk '{print NF}') -ne "$ngpus" ] && \ echo "number of GPU ids in --free-gpu=$free_gpu does not match --ngpus=$ngpus" && exit 1; @@ -267,7 +268,7 @@ if [ ${stage} -le 8 ]; then log_file=$dir/logs/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --seed 1 \ - --log-interval 1000 --log-format simple --print-training-sample-interval 1000 \ + --log-interval 400 --log-format simple --print-training-sample-interval 1000 \ --num-workers 0 --max-tokens 24000 --max-sentences 32 \ --valid-subset $valid_subset --max-sentences-valid 64 \ --distributed-world-size $ngpus --distributed-rank 0 --distributed-port 100 \ @@ -296,7 +297,7 @@ if [ ${stage} -le 9 ]; then decode_affix=shallow_fusion else path="$path:$wordlmdir/$lm_checkpoint" - opts="$opts --word-dict $wordlmdict --lm-weight 0.8 --oov-penalty 1e-6 --coverage-weight 0.01" + opts="$opts --word-dict $wordlmdict --lm-weight 0.8 --oov-penalty 5e-7 --coverage-weight 0.01" decode_affix=shallow_fusion_wordlm fi fi diff --git a/fairseq/criterions/cross_entropy_with_wer.py b/fairseq/criterions/cross_entropy_with_wer.py index f44e24ec8..b30784fd3 100644 --- a/fairseq/criterions/cross_entropy_with_wer.py +++ b/fairseq/criterions/cross_entropy_with_wer.py @@ -27,7 +27,7 @@ def __init__(self, args, task): dict = task.target_dictionary self.scorer = wer.Scorer(dict, wer_output_filter=task.args.wer_output_filter) - self.train_tgt_dataset = task.dataset(args.train_subset).tgt + self.train_tgt_dataset = None self.valid_tgt_dataset = None self.num_updates = -1 self.epoch = 0 @@ -40,8 +40,8 @@ def add_args(parser): metavar='N', dest='print_interval', default=500, help='print a training sample (reference + ' 'prediction) every this number of updates') - parser.add_argument('--scheduled-sampling-probs', type=eval_str_list, - metavar='P_1,P_2,...,P_N', default=[1.0], + parser.add_argument('--scheduled-sampling-probs', type=lambda p: eval_str_list(p), + metavar='P_1,P_2,...,P_N', default=1.0, help='schedule sampling probabilities of sampling the truth ' 'labels for N epochs starting from --start-schedule-sampling-epoch; ' 'all later epochs using P_N') @@ -162,7 +162,8 @@ def forward(self, model, sample, reduce=True): if id < len( self.valid_tgt_dataset): ref_tokens = self.valid_tgt_dataset.get_original_tokens(id) pred_tokens = dict.string(pred.data[i]) - self.scorer.add_evaluation(utt_id, ref_tokens, pred_tokens) + self.scorer.add_evaluation(utt_id, ref_tokens, + pred_tokens, bpe_symbol=self.args.remove_bpe) else: # print a randomly sampled result every 
print_interval updates assert pred.size() == target.size() with data_utils.numpy_seed(self.num_updates): @@ -180,8 +181,12 @@ def forward(self, model, sample, reduce=True): print('| sample PRD: ' + pred_one) # word error stats code ends lprobs = lprobs.view(-1, lprobs.size(-1)) - loss = F.nll_loss(lprobs, target.view(-1), ignore_index=self.padding_idx, - reduction='sum' if reduce else 'none') + loss = F.nll_loss( + lprobs, + target.view(-1), + ignore_index=self.padding_idx, + reduction='sum' if reduce else 'none', + ) sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] logging_output = { 'loss': utils.item(loss.data) if reduce else loss.data, @@ -227,6 +232,9 @@ def _decode(self, tokens, model, encoder_out, incremental_states): probs = probs[:, -1, :] return probs, attn + def set_train_tgt_dataset(self, dataset): + self.train_tgt_dataset = dataset + def set_valid_tgt_dataset(self, dataset): self.valid_tgt_dataset = dataset diff --git a/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py b/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py index 5a26808b1..1d4c18c6a 100644 --- a/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py +++ b/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py @@ -48,8 +48,8 @@ def add_args(parser): parser.add_argument('--smoothing-type', type=str, default='uniform', choices=['uniform', 'unigram', 'temporal'], help='label smoothing type. Default: uniform') - parser.add_argument('--scheduled-sampling-probs', type=eval_str_list, - metavar='P_1,P_2,...,P_N', default=[1.0], + parser.add_argument('--scheduled-sampling-probs', type=lambda p: eval_str_list(p), + metavar='P_1,P_2,...,P_N', default=1.0, help='scheduled sampling probabilities of sampling the truth ' 'labels for N epochs starting from --start-schedule-sampling-epoch; ' 'all later epochs using P_N') @@ -170,7 +170,8 @@ def forward(self, model, sample, reduce=True): if id < len( self.valid_tgt_dataset): ref_tokens = self.valid_tgt_dataset.get_original_tokens(id) pred_tokens = dict.string(pred.data[i]) - self.scorer.add_evaluation(utt_id, ref_tokens, pred_tokens) + self.scorer.add_evaluation(utt_id, ref_tokens, + pred_tokens, bpe_symbol=self.args.remove_bpe) else: # print a randomly sampled result every print_interval updates assert pred.size() == target.size() with data_utils.numpy_seed(self.num_updates): diff --git a/fairseq/data/token_dictionary.py b/fairseq/data/token_dictionary.py index 3f2437c50..1060faee2 100644 --- a/fairseq/data/token_dictionary.py +++ b/fairseq/data/token_dictionary.py @@ -101,7 +101,7 @@ def dummy_sentence(self, length): def tokens_to_sentence(self, line, line_tokenizer=tokenize_line, use_unk_sym=True, bpe_symbol=None): if bpe_symbol is not None: - return data_utils.process_bpe_symbol(sent, bpe_symbol) + return data_utils.process_bpe_symbol(line, bpe_symbol) # use_unk_sym=False when we want to restore original transcripts from # token sequences, e.g., obtain reference to compute WER tokens = line_tokenizer(line) diff --git a/fairseq/models/speech_lstm.py b/fairseq/models/speech_lstm.py index acb9cfd81..b2f9a2bf0 100644 --- a/fairseq/models/speech_lstm.py +++ b/fairseq/models/speech_lstm.py @@ -49,9 +49,11 @@ def add_args(parser): help='encoder rnn\'s hidden size') parser.add_argument('--encoder-rnn-layers', type=int, metavar='N', help='number of rnn encoder layers') - parser.add_argument('--encoder-rnn-bidirectional', action='store_true', + parser.add_argument('--encoder-rnn-bidirectional', + type=lambda x: 
options.eval_bool(x), help='make all rnn layers of encoder bidirectional') - parser.add_argument('--encoder-rnn-residual', action='store_true', + parser.add_argument('--encoder-rnn-residual', + type=lambda x: options.eval_bool(x), help='create residual connections for rnn encoder ' 'layers (starting from the 2nd layer), i.e., the actual ' 'output of such layer is the sum of its input and output') @@ -67,7 +69,8 @@ def add_args(parser): help='number of decoder layers') parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N', help='decoder output embedding dimension') - parser.add_argument('--decoder-rnn-residual', action='store_true', + parser.add_argument('--decoder-rnn-residual', + type=lambda x: options.eval_bool(x), help='create residual connections for rnn decoder ' 'layers (starting from the 2nd layer), i.e., the actual ' 'output of such layer is the sum of its input and output') @@ -81,7 +84,8 @@ def add_args(parser): parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', help='comma separated list of adaptive softmax cutoff points. ' 'Must be used with adaptive_loss criterion') - parser.add_argument('--share-decoder-input-output-embed', action='store_true', + parser.add_argument('--share-decoder-input-output-embed', + type=lambda x: options.eval_bool(x), help='share decoder input and output embeddings') parser.add_argument('--pretrained-lm-checkpoint', type=str, metavar='STR', help='path to load checkpoint from pretrained language model(LM), ' @@ -253,7 +257,8 @@ def add_args(parser): parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', help='comma separated list of adaptive softmax cutoff points. ' 'Must be used with adaptive_loss criterion') - parser.add_argument('--share-embed', action='store_true', + parser.add_argument('--share-embed', + type=lambda x: options.eval_bool(x), help='share input and output embeddings') parser.add_argument('--is-wordlm', action='store_true', help='whether it is word LM or subword LM. 
Only ' @@ -434,9 +439,8 @@ def forward(self, src_tokens, src_lengths): # B x T x C -> T x B x C x = x.transpose(0, 1) - state_size = (2 if self.bidirectional else 1) * self.num_layers, bsz, self.hidden_size + state_size = 2 if self.bidirectional else 1, bsz, self.hidden_size h0, c0 = x.new_zeros(*state_size), x.new_zeros(*state_size) - final_hiddens, final_cells = x.new_empty(*state_size), x.new_empty(*state_size) for i in range(len(self.lstm)): if self.residual and i > 0: # residual connection starts from the 2nd layer @@ -445,11 +449,7 @@ def forward(self, src_tokens, src_lengths): packed_x = nn.utils.rnn.pack_padded_sequence(x, src_lengths.data.tolist()) # apply LSTM - h0_i = h0[i * 2 : (i + 1) * 2] - c0_i = c0[i * 2 : (i + 1) * 2] - final_hiddens_i = final_hiddens[i * 2 : (i + 1) * 2] - final_cells_i = final_cells[i * 2 : (i + 1) * 2] - packed_outs, (final_hiddens_i, final_cells_i) = self.lstm[i](packed_x, (h0_i, c0_i)) + packed_outs, (_, _) = self.lstm[i](packed_x, (h0, c0)) # unpack outputs and apply dropout x, _ = nn.utils.rnn.pad_packed_sequence(packed_outs, padding_value=self.padding_value) @@ -458,19 +458,10 @@ def forward(self, src_tokens, src_lengths): x = x + prev_x if self.residual and i > 0 else x assert list(x.size()) == [seqlen, bsz, self.output_units] - if self.bidirectional: - - def combine_bidir(outs): - out = outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous() - return out.view(self.num_layers, bsz, -1) - - final_hiddens = combine_bidir(final_hiddens) - final_cells = combine_bidir(final_cells) - encoder_padding_mask = padding_mask.t() return { - 'encoder_out': (x, final_hiddens, final_cells), + 'encoder_out': (x,), 'encoder_padding_mask': encoder_padding_mask if encoder_padding_mask.any() else None } @@ -674,8 +665,8 @@ def output_layer(self, features, **kwargs): if self.adaptive_softmax is None: # project back to size of vocabulary if hasattr(self, 'additional_fc'): - x = self.additional_fc(features) - return F.dropout(x, p=self.dropout_out, training=self.training) + features = self.additional_fc(features) + features = F.dropout(features, p=self.dropout_out, training=self.training) if self.share_input_output_embed: return F.linear(features, self.embed_tokens.weight) else: @@ -775,11 +766,11 @@ def lstm_lm_wsj(args): @register_model_architecture('lstm_lm', 'lstm_lm_librispeech') def lstm_lm_librispeech(args): - args.dropout = getattr(args, 'dropout', 0.2) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024) - args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 1024) - args.decoder_layers = getattr(args, 'decoder_layers', 1) - args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 1024) + args.dropout = getattr(args, 'dropout', 0.0) + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 360) + args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 720) + args.decoder_layers = getattr(args, 'decoder_layers', 4) + args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 360) args.share_embed = getattr(args, 'share_embed', True) base_lm_architecture(args) @@ -823,7 +814,7 @@ def base_architecture(args): args.encoder_rnn_dropout_out = getattr(args, 'encoder_rnn_dropout_out', args.dropout) args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', args.dropout) args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout) - args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,50000,200000') + args.adaptive_softmax_cutoff = getattr(args, 
'adaptive_softmax_cutoff', None) args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False) args.pretrained_lm_checkpoint = getattr(args, 'pretrained_lm_checkpoint', None) @@ -838,10 +829,11 @@ def speech_conv_lstm_librispeech(args): args.dropout = getattr(args, 'dropout', 0.3) args.encoder_rnn_hidden_size = getattr(args, 'encoder_rnn_hidden_size', 1024) args.encoder_rnn_layers = getattr(args, 'encoder_rnn_layers', 3) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024) + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512) args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 1024) args.decoder_layers = getattr(args, 'decoder_layers', 3) args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 3072) + args.decoder_rnn_residual = getattr(args, 'decoder_rnn_residual', True) args.attention_type = getattr(args, 'attention_type', 'bahdanau') - args.attention_dim = getattr(args, 'attention_dim', 1024) + args.attention_dim = getattr(args, 'attention_dim', 512) base_architecture(args) diff --git a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau_v2.py b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau_v2.py index 14c5a11bf..c9cdabf7a 100644 --- a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau_v2.py +++ b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau_v2.py @@ -13,10 +13,21 @@ @register_lr_scheduler('reduce_lr_on_plateau_v2') class ReduceLROnPlateauV2(ReduceLROnPlateau): - """Decay the LR by a factor every time the validation loss plateaus, after start_epoch_to_reduce.""" + """Decay the LR by a factor every time the validation loss plateaus, starting + from the epoch specified as args.start_reduce_lr_epoch. + + We also support a warmup phase where we linearly increase the learning rate + from 0 until the configured learning rate (``--lr``). + """ def __init__(self, args, optimizer): super().__init__(args, optimizer) + + if args.warmup_updates > 0: + self.warmup_factor = 1. / args.warmup_updates + else: + self.warmup_factor = 1. 
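+        # Sketch of the resulting schedule (hypothetical values): with
+        # --warmup-updates 4000 and --lr 0.001, step_update() below sets
+        # lr = (num_updates / 4000) * 0.001 during warmup, e.g. 0.00025 at
+        # update 1000; once num_updates reaches 4000 the factor stays at 1.0
+        # and the ReduceLROnPlateau logic takes over.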
+ self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( self.optimizer.optimizer, patience=0, factor=args.lr_shrink, threshold=args.lr_threshold, min_lr=args.min_lr) @@ -26,12 +37,22 @@ def add_args(parser): """Add arguments to the parser for this LR scheduler.""" ReduceLROnPlateau.add_args(parser) # fmt: off + parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', + help='warmup the learning rate linearly for the first N updates') parser.add_argument('--start-reduce-lr-epoch', default=0, type=int, metavar='N', - help='start to reduce lr from specified epoch') + help='start to reduce lr from the specified epoch') # fmt: on def step(self, epoch, val_loss=None): if epoch < self.args.start_reduce_lr_epoch: self.lr_scheduler.last_epoch = epoch - return self.args.lr[0] + self.optimizer.set_lr(self.warmup_factor * self.args.lr[0]) + return self.optimizer.get_lr() return super().step(epoch, val_loss) + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + if self.args.warmup_updates > 0 and num_updates <= self.args.warmup_updates: + self.warmup_factor = num_updates / float(self.args.warmup_updates) + self.optimizer.set_lr(self.warmup_factor * self.args.lr[0]) + return self.optimizer.get_lr() diff --git a/fairseq/tasks/speech_recognition.py b/fairseq/tasks/speech_recognition.py index c82986dc1..c5c733ca9 100644 --- a/fairseq/tasks/speech_recognition.py +++ b/fairseq/tasks/speech_recognition.py @@ -227,22 +227,22 @@ def build_generator(self, args): from fairseq.sequence_generator import SequenceGenerator return SequenceGenerator( self.target_dictionary, - beam_size=args.beam, - max_len_a=args.max_len_a, - max_len_b=args.max_len_b, - min_len=args.min_len, - stop_early=(not args.no_early_stop), - normalize_scores=(not args.unnormalized), - len_penalty=args.lenpen, - unk_penalty=args.unkpen, - sampling=args.sampling, - sampling_topk=args.sampling_topk, - temperature=args.temperature, - diverse_beam_groups=args.diverse_beam_groups, - diverse_beam_strength=args.diverse_beam_strength, - match_source_len=args.match_source_len, - no_repeat_ngram_size=args.no_repeat_ngram_size, - coverage_weight=args.coverage_weight, + beam_size=getattr(args, 'beam', 5), + max_len_a=getattr(args, 'max_len_a', 0), + max_len_b=getattr(args, 'max_len_b', 200), + min_len=getattr(args, 'min_len', 1), + stop_early=(not getattr(args, 'no_early_stop', False)), + normalize_scores=(not getattr(args, 'unnormalized', False)), + len_penalty=getattr(args, 'lenpen', 1), + unk_penalty=getattr(args, 'unkpen', 0), + sampling=getattr(args, 'sampling', False), + sampling_topk=getattr(args, 'sampling_topk', -1), + temperature=getattr(args, 'temperature', 1.), + diverse_beam_groups=getattr(args, 'diverse_beam_groups', -1), + diverse_beam_strength=getattr(args, 'diverse_beam_strength', 0.5), + match_source_len=getattr(args, 'match_source_len', False), + no_repeat_ngram_size=getattr(args, 'no_repeat_ngram_size', 0), + coverage_weight=getattr(args, 'coverage_weight', 0.0), ) def build_dataset_for_inference(self, src_tokens, src_lengths): diff --git a/speech_recognize.py b/speech_recognize.py index 3614cba6a..4d2e80896 100755 --- a/speech_recognize.py +++ b/speech_recognize.py @@ -74,7 +74,7 @@ def main(args): for model in models: model.make_generation_fast_( beamable_mm_beam_size=None if args.no_beamable_mm else args.beam, - need_attn=args.print_alignment, + need_attn=args.print_alignment or args.coverage_weight > 0., ) if args.fp16: model.half() @@ -152,7 +152,7 @@ def main(args): 
# src_len x tgt_len attention = hypo['attention'].float().cpu() \ if hypo['attention'] is not None else None - if attention is not None: + if args.print_alignment and attention is not None: save_dir = os.path.join(args.results_path, 'attn_plots') os.makedirs(save_dir, exist_ok=True) plot_attention(attention, hypo_sent, utt_id, save_dir) diff --git a/speech_tools/Makefile b/speech_tools/Makefile index 28b5c2689..b94b4140a 100644 --- a/speech_tools/Makefile +++ b/speech_tools/Makefile @@ -12,7 +12,7 @@ kaldi-io-for-python: sentencepiece: rm -rf sentencepiece git clone https://github.com/google/sentencepiece.git - cd sentencepiece && mkdir build && cd build && (cmake3 .. || cmake ..) && $(MAKE) + cd sentencepiece && git checkout v0.1.82 && mkdir build && cd build && (cmake3 .. || cmake ..) && $(MAKE) ifneq ($(strip $(KALDI)),) kaldi: diff --git a/speech_train.py b/speech_train.py index 83bcd121a..97c74bde1 100755 --- a/speech_train.py +++ b/speech_train.py @@ -80,9 +80,9 @@ def main(args, init_distributed=False): train_meter.start() valid_losses, valid_wers = [None], [None] valid_subsets = args.valid_subset.split(',') - while lr > args.min_lr and (epoch_itr.epoch < max_epoch or \ - (epoch_itr.epoch == max_epoch and epoch_itr._next_epoch_itr is not None)) and \ - trainer.get_num_updates() < max_update: + while (lr > args.min_lr or trainer.get_num_updates() <= getattr(args, 'warmup_updates', 0)) and \ + (epoch_itr.epoch < max_epoch or (epoch_itr.epoch == max_epoch and \ + epoch_itr._next_epoch_itr is not None)) and trainer.get_num_updates() < max_update: # train for one epoch train(args, trainer, task, epoch_itr) From f4bcc7f3ba29449313d59ebdb30fae8933160b89 Mon Sep 17 00:00:00 2001 From: freewym Date: Sun, 16 Jun 2019 02:34:57 -0400 Subject: [PATCH 023/119] add Transformer and FConv model for ASR --- fairseq/criterions/cross_entropy_with_wer.py | 6 +- .../label_smoothed_cross_entropy_with_wer.py | 6 +- fairseq/models/speech_fconv.py | 367 ++++++++++++++++++ fairseq/models/speech_transformer.py | 328 ++++++++++++++++ 4 files changed, 701 insertions(+), 6 deletions(-) create mode 100644 fairseq/models/speech_fconv.py create mode 100644 fairseq/models/speech_transformer.py diff --git a/fairseq/criterions/cross_entropy_with_wer.py b/fairseq/criterions/cross_entropy_with_wer.py index b30784fd3..399bad979 100644 --- a/fairseq/criterions/cross_entropy_with_wer.py +++ b/fairseq/criterions/cross_entropy_with_wer.py @@ -113,7 +113,7 @@ def forward(self, model, sample, reduce=True): tokens = target.new_full([target.size(0), maxlen + 2], self.padding_idx) tokens[:, 0] = dict.eos() lprobs = [] - attn = [] if model.decoder.need_attn else None + attn = [] if getattr(model.decoder, 'need_attn', False) else None dummy_log_probs = encoder_out['encoder_out'][0].new_full( [target.size(0), len(dict)], -np.log(len(dict))) for step in range(maxlen + 1): # one extra step for EOS marker @@ -136,11 +136,11 @@ def forward(self, model, sample, reduce=True): tokens[is_eos, step + 1] = dict.eos() if step < target.size(1): lprobs.append(log_probs) - if model.decoder.need_attn: + if getattr(model.decoder, 'need_attn', False): attn.append(attn_scores) # bsz x min(tgtlen, maxlen + 1) x vocab_size lprobs = torch.stack(lprobs, dim=1) - if model.decoder.need_attn: + if getattr(model.decoder, 'need_attn', False): # bsz x (maxlen + 1) x (length of encoder_out) attn = torch.stack(attn, dim=1) # word error stats code starts diff --git a/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py 
b/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py index 1d4c18c6a..5ea5fb802 100644 --- a/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py +++ b/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py @@ -121,7 +121,7 @@ def forward(self, model, sample, reduce=True): tokens = target.new_full([target.size(0), maxlen + 2], self.padding_idx) tokens[:, 0] = dict.eos() lprobs = [] - attn = [] if model.decoder.need_attn else None + attn = [] if getattr(model.decoder, 'need_attn', False) else None dummy_log_probs = encoder_out['encoder_out'][0].new_full( [target.size(0), len(dict)], -np.log(len(dict))) for step in range(maxlen + 1): # one extra step for EOS marker @@ -144,11 +144,11 @@ def forward(self, model, sample, reduce=True): tokens[is_eos, step + 1] = dict.eos() if step < target.size(1): lprobs.append(log_probs) - if model.decoder.need_attn: + if getattr(model.decoder, 'need_attn', False): attn.append(attn_scores) # bsz x min(tgtlen, maxlen + 1) x vocab_size lprobs = torch.stack(lprobs, dim=1) - if model.decoder.need_attn: + if getattr(model.decoder, 'need_attn', False): # bsz x (maxlen + 1) x (length of encoder_out) attn = torch.stack(attn, dim=1) # word error stats code starts diff --git a/fairseq/models/speech_fconv.py b/fairseq/models/speech_fconv.py new file mode 100644 index 000000000..b8e387e92 --- /dev/null +++ b/fairseq/models/speech_fconv.py @@ -0,0 +1,367 @@ +# Copyright (c) 2019-present, Yiming Wang +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairseq import utils +from fairseq.models import ( + register_model, + register_model_architecture, +) +from fairseq.modules import GradMultiply + +from .speech_lstm import ConvBNReLU + +from .fconv import ( + ConvTBC, + FConvModel, + FConvEncoder, + FConvDecoder, + Linear, + extend_conv_spec, +) + +import speech_tools.utils as speech_utils + + +@register_model('speech_fconv') +class SpeechFConvModel(FConvModel): + """ + A fully convolutional model, i.e. a convolutional encoder and a + convolutional decoder, as described in `"Convolutional Sequence to Sequence + Learning" (Gehring et al., 2017) `_. + + Args: + encoder (FConvEncoder): the encoder + decoder (FConvDecoder): the decoder + + The Convolutional model provides the following named architectures and + command-line arguments: + + .. 
argparse:: + :ref: fairseq.models.fconv_parser + :prog: + """ + + @classmethod + def hub_models(cls): + raise NotImplementedError + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + # fmt: off + FConvModel.add_args(parser) + parser.add_argument('--encoder-conv-channels', type=str, metavar='EXPR', + help='list of encoder convolution\'s out channels') + parser.add_argument('--encoder-conv-kernel-sizes', type=str, metavar='EXPR', + help='list of encoder convolution\'s kernel sizes') + parser.add_argument('--encoder-conv-strides', type=str, metavar='EXPR', + help='list of encoder convolution\'s strides') + parser.add_argument('--decoder-positional-embed', action='store_true', + help='use decoder positional embeddings') + # fmt: on + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + # make sure that all args are properly defaulted (in case there are any new ones) + base_architecture(args) + + decoder_embed_dict = None + if args.decoder_embed_path: + decoder_embed_dict = utils.parse_embedding(args.decoder_embed_path) + utils.print_embed_overlap(decoder_embed_dict, task.target_dictionary) + + def eval_str_nested_list_or_tuple(x, type=int): + if x is None: + return None + if isinstance(x, str): + x = eval(x) + if isinstance(x, list): + return list( + map(lambda s: eval_str_nested_list_or_tuple(s, type), x)) + elif isinstance(x, tuple): + return tuple( + map(lambda s: eval_str_nested_list_or_tuple(s, type), x)) + else: + try: + return type(x) + except: + raise ValueError + + out_channels = eval_str_nested_list_or_tuple(args.encoder_conv_channels, + type=int) + kernel_sizes = eval_str_nested_list_or_tuple( + args.encoder_conv_kernel_sizes, type=int) + strides = eval_str_nested_list_or_tuple(args.encoder_conv_strides, + type=int) + print('| input feature dimension: {}, channels: {}'.format(task.feat_dim, + task.feat_in_channels)) + assert task.feat_dim % task.feat_in_channels == 0 + conv_layers = ConvBNReLU(out_channels, kernel_sizes, strides, + in_channels=task.feat_in_channels) if not out_channels is None else None + + fconv_encoder_input_size = task.feat_dim // task.feat_in_channels + if conv_layers is not None: + for stride in strides: + if isinstance(stride, (list, tuple)): + assert len(stride) > 0 + s = stride[1] if len(stride) > 1 else stride[0] + else: + assert isinstance(stride, int) + s = stride + fconv_encoder_input_size = (fconv_encoder_input_size + s - 1) // s + fconv_encoder_input_size *= out_channels[-1] + + encoder = SpeechFConvEncoder( + conv_layers_before=conv_layers, + input_size=fconv_encoder_input_size, + embed_dim=args.encoder_embed_dim, + convolutions=eval(args.encoder_layers), + dropout=args.dropout, + ) + decoder = SpeechFConvDecoder( + dictionary=task.target_dictionary, + embed_dim=args.decoder_embed_dim, + embed_dict=decoder_embed_dict, + convolutions=eval(args.decoder_layers), + out_embed_dim=args.decoder_out_embed_dim, + attention=eval(args.decoder_attention), + dropout=args.dropout, + max_positions=args.max_target_positions, + share_embed=args.share_input_output_embed, + positional_embeddings=args.decoder_positional_embed, + ) + return SpeechFConvModel(encoder, decoder) + + +class SpeechFConvEncoder(FConvEncoder): + """ + Convolutional encoder consisting of `len(convolutions)` layers. 
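+
+    Unlike the text FConvEncoder, this encoder consumes speech feature frames
+    rather than token indices (it is constructed without a source dictionary);
+    inputs may first pass through `conv_layers_before` and, if the feature
+    dimension differs from `embed_dim`, an extra `fc0` projection.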
+ + Args: + conv_layers_before (~fairseq.speech_lstm.ConvBNReLU): convolutions befoe + fconv layers + input_size (int, optional): dim of input to the transformer before being + projected to embed_dim + embed_dim (int, optional): embedding dimension + max_positions (int, optional): maximum supported input sequence length + convolutions (list, optional): the convolutional layer structure. Each + list item `i` corresponds to convolutional layer `i`. Layers are + given as ``(out_channels, kernel_width, [residual])``. Residual + connections are added between layers when ``residual=1`` (which is + the default behavior). + dropout (float, optional): dropout to be applied before each conv layer + """ + + def __init__( + self, conv_layers_before=None, input_size=83, embed_dim=512, + convolutions=((512, 3),) * 20, dropout=0.1, + ): + super(FConvEncoder, self).__init__(None) # no src dictionary + self.dropout = dropout + self.num_attention_layers = None + + self.conv_layers_before = conv_layers_before + self.fc0 = Linear(input_size, embed_dim, dropout=dropout) \ + if input_size != embed_dim else None + + convolutions = extend_conv_spec(convolutions) + in_channels = convolutions[0][0] + self.fc1 = Linear(embed_dim, in_channels, dropout=dropout) + self.projections = nn.ModuleList() + self.convolutions = nn.ModuleList() + self.residuals = [] + + layer_in_channels = [in_channels] + for _, (out_channels, kernel_size, residual) in enumerate(convolutions): + if residual == 0: + residual_dim = out_channels + else: + residual_dim = layer_in_channels[-residual] + self.projections.append(Linear(residual_dim, out_channels) + if residual_dim != out_channels else None) + if kernel_size % 2 == 1: + padding = kernel_size // 2 + else: + padding = 0 + self.convolutions.append( + ConvTBC(in_channels, out_channels * 2, kernel_size, + dropout=dropout, padding=padding) + ) + self.residuals.append(residual) + in_channels = out_channels + layer_in_channels.append(out_channels) + self.fc2 = Linear(in_channels, embed_dim) + + def output_lengths(self, in_lengths): + return in_lengths if self.conv_layers_before is None \ + else self.conv_layers_before.output_lengths(in_lengths) + + def forward(self, src_tokens, src_lengths): + """ + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (LongTensor): lengths of each source sentence of shape + `(batch)` + + Returns: + dict: + - **encoder_out** (tuple): a tuple with two elements, where the + first element is the last encoder layer's output and the + second element is the same quantity summed with the input + embedding (used for attention). The shape of both tensors is + `(batch, src_len, embed_dim)`. 
+ - **encoder_padding_mask** (ByteTensor): the positions of + padding elements of shape `(batch, src_len)` + """ + if self.conv_layers_before is not None: + x, src_lengths, encoder_padding_mask = self.conv_layers_before(src_tokens, + src_lengths) + else: + x, encoder_padding_mask = src_tokens, \ + ~speech_utils.sequence_mask(src_lengths, src_tokens.size(1)) + + x = F.dropout(x, p=self.dropout, training=self.training) + if self.fc0 is not None: + x = self.fc0(x) + x = F.dropout(x, p=self.dropout, training=self.training) + input_embedding = x + + # project to size of convolution + x = self.fc1(x) + + encoder_padding_mask = encoder_padding_mask.t() # -> T x B + if not encoder_padding_mask.any(): + encoder_padding_mask = None + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + residuals = [x] + # temporal convolutions + for proj, conv, res_layer in zip(self.projections, self.convolutions, self.residuals): + if res_layer > 0: + residual = residuals[-res_layer] + residual = residual if proj is None else proj(residual) + else: + residual = None + + if encoder_padding_mask is not None: + x = x.masked_fill(encoder_padding_mask.unsqueeze(-1), 0) + + x = F.dropout(x, p=self.dropout, training=self.training) + if conv.kernel_size[0] % 2 == 1: + # padding is implicit in the conv + x = conv(x) + else: + padding_l = (conv.kernel_size[0] - 1) // 2 + padding_r = conv.kernel_size[0] // 2 + x = F.pad(x, (0, 0, 0, 0, padding_l, padding_r)) + x = conv(x) + x = F.glu(x, dim=2) + + if residual is not None: + x = (x + residual) * math.sqrt(0.5) + residuals.append(x) + + # T x B x C -> B x T x C + x = x.transpose(1, 0) + + # project back to size of embedding + x = self.fc2(x) + + if encoder_padding_mask is not None: + encoder_padding_mask = encoder_padding_mask.t() # -> B x T + x = x.masked_fill(encoder_padding_mask.unsqueeze(-1), 0) + + # scale gradients (this only affects backward, not forward) + x = GradMultiply.apply(x, 1.0 / (2.0 * self.num_attention_layers)) + + # add output to input embedding for attention + y = (x + input_embedding) * math.sqrt(0.5) + + return { + 'encoder_out': (x, y), + 'encoder_padding_mask': encoder_padding_mask, # B x T + } + + def reorder_encoder_out(self, encoder_out, new_order): + if encoder_out['encoder_out'] is not None: + encoder_out['encoder_out'] = ( + encoder_out['encoder_out'][0].index_select(0, new_order), + encoder_out['encoder_out'][1].index_select(0, new_order), + ) + if encoder_out['encoder_padding_mask'] is not None: + encoder_out['encoder_padding_mask'] = \ + encoder_out['encoder_padding_mask'].index_select(0, new_order) + return encoder_out + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return int(1e5) + + +class SpeechFConvDecoder(FConvDecoder): + def masked_copy_incremental_state(self, incremental_state, another_state, mask): + state = utils.get_incremental_state(self, incremental_state, 'encoder_out') + if state is None: + assert another_state is None + return + + def mask_copy_state(state, another_state): + if isinstance(state, list): + assert isinstance(another_state, list) and len(state) == len(another_state) + return [mask_copy_state(state_i, another_state_i) \ + for state_i, another_state_i in zip(state, another_state)] + if state is not None: + assert state.size(0) == mask.size(0) and another_state is not None and \ + state.size() == another_state.size() + for _ in range(1, len(state.size())): + mask_unsqueezed = mask.unsqueeze(-1) + return torch.where(mask_unsqueezed, state, another_state) + else: + assert 
another_state is None + return None + + new_state = tuple(map(mask_copy_state, state, another_state)) + utils.set_incremental_state(self, incremental_state, 'encoder_out', new_state) + + +@register_model_architecture('speech_fconv', 'speech_fconv') +def base_architecture(args): + args.encoder_conv_channels = getattr(args, 'encoder_conv_channels', + '[64, 64, 128, 128]') + args.encoder_conv_kernel_sizes = getattr(args, 'encoder_conv_kernel_sizes', + '[(3, 3), (3, 3), (3, 3), (3, 3)]') + args.encoder_conv_strides = getattr(args, 'encoder_conv_strides', + '[(1, 1), (2, 2), (1, 1), (2, 2)]') + args.dropout = getattr(args, 'dropout', 0.1) + args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512) + args.encoder_layers = getattr(args, 'encoder_layers', '[(512, 3)] * 20') + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512) + args.decoder_embed_path = getattr(args, 'decoder_embed_path', None) + args.decoder_layers = getattr(args, 'decoder_layers', '[(512, 3)] * 20') + args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256) + args.decoder_attention = getattr(args, 'decoder_attention', 'True') + args.share_input_output_embed = getattr(args, 'share_input_output_embed', False) + args.decoder_positional_embed = getattr(args, 'decoder_positional_embed', False) + + +@register_model_architecture('speech_fconv', 'speech_fconv_librispeech') +def speech_fconv_librispeech(args): + args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256) + args.encoder_layers = getattr(args, 'encoder_layers', '[(256, 3)] * 4') + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256) + args.decoder_layers = getattr(args, 'decoder_layers', '[(256, 3)] * 3') + args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256) + base_architecture(args) diff --git a/fairseq/models/speech_transformer.py b/fairseq/models/speech_transformer.py new file mode 100644 index 000000000..e73143d0a --- /dev/null +++ b/fairseq/models/speech_transformer.py @@ -0,0 +1,328 @@ +# Copyright (c) 2019-present, Yiming Wang +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairseq import utils +from fairseq.models import ( + register_model, + register_model_architecture, +) +from fairseq.modules import LayerNorm + +from .speech_lstm import ConvBNReLU + +from .transformer import ( + Embedding, + Linear, + TransformerModel, + TransformerEncoder, + TransformerDecoder, + TransformerEncoderLayer, +) + +import speech_tools.utils as speech_utils + +DEFAULT_MAX_SOURCE_POSITIONS = 9999 +DEFAULT_MAX_TARGET_POSITIONS = 999 + + +@register_model('speech_transformer') +class SpeechTransformerModel(TransformerModel): + """ + Transformer model from `"Attention Is All You Need" (Vaswani, et al, 2017) + `_. + + Args: + encoder (TransformerEncoder): the encoder + decoder (TransformerDecoder): the decoder + + The Transformer model provides the following named architectures and + command-line arguments: + + .. 
argparse:: + :ref: fairseq.models.transformer_parser + :prog: + """ + + @classmethod + def hub_models(cls): + raise NotImplementedError + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + # fmt: off + TransformerModel.add_args(parser) + parser.add_argument('--encoder-conv-channels', type=str, metavar='EXPR', + help='list of encoder convolution\'s out channels') + parser.add_argument('--encoder-conv-kernel-sizes', type=str, metavar='EXPR', + help='list of encoder convolution\'s kernel sizes') + parser.add_argument('--encoder-conv-strides', type=str, metavar='EXPR', + help='list of encoder convolution\'s strides') + # fmt: on + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present in older models + base_architecture(args) + + if not hasattr(args, 'max_source_positions'): + args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS + if not hasattr(args, 'max_target_positions'): + args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS + + dict = task.target_dictionary + + def build_embedding(dictionary, embed_dim, path=None): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + emb = Embedding(num_embeddings, embed_dim, padding_idx) + # if provided, load from preloaded dictionaries + if path: + embed_dict = utils.parse_embedding(path) + utils.load_embedding(embed_dict, dictionary, emb) + return emb + + decoder_embed_tokens = build_embedding( + dict, args.decoder_embed_dim, args.decoder_embed_path + ) + + def eval_str_nested_list_or_tuple(x, type=int): + if x is None: + return None + if isinstance(x, str): + x = eval(x) + if isinstance(x, list): + return list( + map(lambda s: eval_str_nested_list_or_tuple(s, type), x)) + elif isinstance(x, tuple): + return tuple( + map(lambda s: eval_str_nested_list_or_tuple(s, type), x)) + else: + try: + return type(x) + except: + raise ValueError + + out_channels = eval_str_nested_list_or_tuple(args.encoder_conv_channels, + type=int) + kernel_sizes = eval_str_nested_list_or_tuple( + args.encoder_conv_kernel_sizes, type=int) + strides = eval_str_nested_list_or_tuple(args.encoder_conv_strides, + type=int) + print('| input feature dimension: {}, channels: {}'.format(task.feat_dim, + task.feat_in_channels)) + assert task.feat_dim % task.feat_in_channels == 0 + conv_layers = ConvBNReLU(out_channels, kernel_sizes, strides, + in_channels=task.feat_in_channels) if not out_channels is None else None + + transformer_encoder_input_size = task.feat_dim // task.feat_in_channels + if conv_layers is not None: + for stride in strides: + if isinstance(stride, (list, tuple)): + assert len(stride) > 0 + s = stride[1] if len(stride) > 1 else stride[0] + else: + assert isinstance(stride, int) + s = stride + transformer_encoder_input_size = \ + (transformer_encoder_input_size + s - 1) // s + transformer_encoder_input_size *= out_channels[-1] + + encoder = cls.build_encoder(args, conv_layers_before=conv_layers, + input_size=transformer_encoder_input_size) + decoder = cls.build_decoder(args, dict, decoder_embed_tokens) + return SpeechTransformerModel(encoder, decoder) + + @classmethod + def build_encoder(cls, args, conv_layers_before=None, input_size=83): + return SpeechTransformerEncoder(args, + conv_layers_before=conv_layers_before, input_size=input_size) + + @classmethod + def build_decoder(cls, args, dict, embed_tokens): + return SpeechTransformerDecoder(args, dict, embed_tokens) + + +class SpeechTransformerEncoder(TransformerEncoder): + """ + 
Transformer encoder consisting of *args.encoder_layers* layers. Each layer + is a :class:`TransformerEncoderLayer`. + + Args: + args (argparse.Namespace): parsed command-line arguments + conv_layers_before (~fairseq.speech_lstm.ConvBNReLU): convolutions befoe + transformer layers + input_size (int, optional): dim of input to the transformer before being + projected to args.encoder_embed_dim + """ + + def __init__(self, args, conv_layers_before=None, input_size=83): + super(TransformerEncoder, self).__init__(None) # no src dictionary + self.register_buffer('version', torch.Tensor([3])) + + self.dropout = args.dropout + + self.conv_layers_before = conv_layers_before + self.fc0 = Linear(input_size, args.encoder_embed_dim) \ + if input_size != args.encoder_embed_dim else None + self.max_source_positions = args.max_source_positions + + self.layers = nn.ModuleList([]) + self.layers.extend([ + TransformerEncoderLayer(args) + for i in range(args.encoder_layers) + ]) + + if args.encoder_normalize_before: + self.layer_norm = LayerNorm(args.encoder_embed_dim) + else: + self.layer_norm = None + + def output_lengths(self, in_lengths): + return in_lengths if self.conv_layers_before is None \ + else self.conv_layers_before.output_lengths(in_lengths) + + def forward(self, src_tokens, src_lengths): + """ + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (torch.LongTensor): lengths of each source sentence of + shape `(batch)` + + Returns: + dict: + - **encoder_out** (Tensor): the last encoder layer's output of + shape `(src_len, batch, embed_dim)` + - **encoder_padding_mask** (ByteTensor): the positions of + padding elements of shape `(batch, src_len)` + """ + if self.conv_layers_before is not None: + x, src_lengths, encoder_padding_mask = self.conv_layers_before(src_tokens, + src_lengths) + else: + x, encoder_padding_mask = src_tokens, \ + ~speech_utils.sequence_mask(src_lengths, src_tokens.size(1)) + + if not encoder_padding_mask.any(): + encoder_padding_mask = None + + x = F.dropout(x, p=self.dropout, training=self.training) + if self.fc0 is not None: + x = self.fc0(x) + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + # encoder layers + for layer in self.layers: + x = layer(x, encoder_padding_mask) + + if self.layer_norm: + x = self.layer_norm(x) + + return { + 'encoder_out': x, # T x B x C + 'encoder_padding_mask': encoder_padding_mask, # B x T + } + + def reorder_encoder_out(self, encoder_out, new_order): + """ + Reorder encoder output according to *new_order*. 
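An aside on ``speech_utils.sequence_mask``, which ``forward`` above uses to derive the padding mask: its implementation lives in ``speech_tools/utils.py`` and is not shown in this excerpt. A minimal sketch of the assumed behavior (a byte/bool mask that is true at valid positions, so ``~mask`` marks padding):

    import torch

    def sequence_mask(lengths, maxlen=None):
        # assumed behavior only; the real helper is defined in speech_tools/utils.py
        if maxlen is None:
            maxlen = int(lengths.max())
        positions = torch.arange(maxlen, device=lengths.device)
        return positions.unsqueeze(0) < lengths.unsqueeze(1)  # shape (batch, maxlen)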
+ + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + *encoder_out* rearranged according to *new_order* + """ + if encoder_out['encoder_out'] is not None: + encoder_out['encoder_out'] = \ + encoder_out['encoder_out'].index_select(1, new_order) + if encoder_out['encoder_padding_mask'] is not None: + encoder_out['encoder_padding_mask'] = \ + encoder_out['encoder_padding_mask'].index_select(0, new_order) + return encoder_out + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return self.max_source_positions + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + for i in range(len(self.layers)): + # update layer norms + self.layers[i].upgrade_state_dict_named(state_dict, "{}.layers.{}".format(name, i)) + + version_key = '{}.version'.format(name) + if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: + # earlier checkpoints did not normalize after the stack of layers + self.layer_norm = None + self.normalize = False + state_dict[version_key] = torch.Tensor([1]) + return state_dict + + +class SpeechTransformerDecoder(TransformerDecoder): + def masked_copy_incremental_state(self, incremental_state, another_cached_state, mask): + pass + +@register_model_architecture('speech_transformer', 'speech_transformer') +def base_architecture(args): + args.encoder_conv_channels = getattr(args, 'encoder_conv_channels', + '[64, 64, 128, 128]') + args.encoder_conv_kernel_sizes = getattr(args, 'encoder_conv_kernel_sizes', + '[(3, 3), (3, 3), (3, 3), (3, 3)]') + args.encoder_conv_strides = getattr(args, 'encoder_conv_strides', + '[(1, 1), (2, 2), (1, 1), (2, 2)]') + args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512) + args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 512) + args.encoder_layers = getattr(args, 'encoder_layers', 6) + args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8) + args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False) + args.decoder_embed_path = getattr(args, 'decoder_embed_path', None) + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512) + args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', args.encoder_ffn_embed_dim) + args.decoder_layers = getattr(args, 'decoder_layers', 6) + args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8) + args.decoder_normalize_before = getattr(args, 'decoder_normalize_before', False) + args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False) + args.attention_dropout = getattr(args, 'attention_dropout', 0.) + args.activation_dropout = getattr(args, 'activation_dropout', 0.) 
+ args.activation_fn = getattr(args, 'activation_fn', 'relu') + args.dropout = getattr(args, 'dropout', 0.1) + args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None) + args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0) + args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False) + args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False) + args.adaptive_input = getattr(args, 'adaptive_input', False) + + args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim) + args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim) + + +@register_model_architecture('speech_transformer', 'speech_transformer_librispeech') +def speech_transformer_librispeech(args): + args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512) + args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 512) + args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 1) + args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False) + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512) + args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 512) + args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 1) + args.dropout = getattr(args, 'dropout', 0.3) + base_architecture(args) From 9eb10202b3d1447f6dc361dc5f027c505df442da Mon Sep 17 00:00:00 2001 From: freewym Date: Thu, 20 Jun 2019 01:12:21 -0400 Subject: [PATCH 024/119] LM arch changes --- examples/asr_librispeech/run.sh | 8 ++++---- examples/asr_wsj/run.sh | 14 ++++++------- fairseq/criterions/cross_entropy_with_wer.py | 2 +- .../label_smoothed_cross_entropy_with_wer.py | 2 +- fairseq/data/token_dictionary.py | 2 +- fairseq/models/speech_lstm.py | 20 +++++++++---------- fairseq/wer.py | 4 ++-- speech_tools/utils.py | 2 +- 8 files changed, 27 insertions(+), 27 deletions(-) diff --git a/examples/asr_librispeech/run.sh b/examples/asr_librispeech/run.sh index 6ae6fa6c8..14ce148da 100755 --- a/examples/asr_librispeech/run.sh +++ b/examples/asr_librispeech/run.sh @@ -174,8 +174,8 @@ if [ ${stage} -le 5 ]; then --log-interval 8000 --log-format simple \ --num-workers 0 --max-tokens 30720 --max-sentences 1024 \ --valid-subset $valid_subset --max-sentences-valid 1536 \ - --distributed-world-size $ngpus --distributed-rank 0 --distributed-port 100 \ - --max-epoch 25 --optimizer adam --lr 0.001 --clip-norm 1.0 \ + --distributed-world-size $ngpus --distributed-port 100 \ + --max-epoch 30 --optimizer adam --lr 0.001 --clip-norm 1.0 \ --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \ --save-dir $lmdir --restore-file checkpoint_last.pt --save-interval-updates 8000 \ --keep-interval-updates 3 --keep-last-epochs 5 --validate-interval 1 \ @@ -211,7 +211,7 @@ if [ ${stage} -le 7 ]; then --log-interval 4000 --log-format simple --print-training-sample-interval 2000 \ --num-workers 0 --max-tokens 26000 --max-sentences 24 \ --valid-subset $valid_subset --max-sentences-valid 48 \ - --distributed-world-size $ngpus --distributed-rank 0 --distributed-port 100 \ + --distributed-world-size $ngpus --distributed-port 100 \ --max-epoch 25 --optimizer adam --lr 0.001 --weight-decay 0.0 --clip-norm 2.0 \ --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --min-lr 1e-5 --start-reduce-lr-epoch 10 \ --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates 3000 \ @@ -243,7 +243,7 @@ if [ ${stage} -le 8 ]; then 
--test-feat-files $feat --test-text-files $text \ --dict $dict --remove-bpe sentencepiece \ --max-source-positions 9999 --max-target-positions 999 \ - --path $path --beam 25 --max-len-a 0.08 --max-len-b 0 --lenpen 1.0 \ + --path $path --beam 20 --max-len-a 0.08 --max-len-b 0 --lenpen 1.0 \ --results-path $dir/decode_$dataset${decode_affix:+_${decode_affix}} $opts \ 2>&1 | tee $dir/logs/decode_$dataset${decode_affix:+_${decode_affix}}.log diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index 3f07108d7..1d7c7a942 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -200,7 +200,7 @@ if [ ${stage} -le 4 ] && ! $use_wordlm; then --log-interval 2000 --log-format simple \ --num-workers 0 --max-tokens 25600 --max-sentences 128 \ --valid-subset $valid_subset --max-sentences-valid 256 \ - --distributed-world-size $ngpus --distributed-rank 0 --distributed-port 100 \ + --distributed-world-size $ngpus --distributed-port 100 \ --max-epoch 25 --optimizer adam --lr 0.001 --weight-decay 5e-06 \ --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \ --save-dir $lmdir --restore-file checkpoint_last.pt --save-interval-updates 2000 \ @@ -228,10 +228,10 @@ if [ ${stage} -le 6 ] && $use_wordlm; then CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../train.py $wordlmdatadir --seed 1 \ --task language_modeling_for_asr --dict $wordlmdict \ --log-interval 2000 --log-format simple \ - --num-workers 0 --max-tokens 6400 --max-sentences 256 \ + --num-workers 0 --max-tokens 6300 --max-sentences 256 \ --valid-subset $valid_subset --max-sentences-valid 512 \ - --distributed-world-size $ngpus --distributed-rank 0 --distributed-port 100 \ - --max-epoch 20 --optimizer adam --lr 0.001 --weight-decay 1e-05 \ + --distributed-world-size $ngpus --distributed-port 100 \ + --max-epoch 25 --optimizer adam --lr 0.001 --weight-decay 0.0 \ --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \ --save-dir $wordlmdir --restore-file checkpoint_last.pt --save-interval-updates 2000 \ --keep-interval-updates 5 --keep-last-epochs 5 --validate-interval 1 \ @@ -271,14 +271,14 @@ if [ ${stage} -le 8 ]; then --log-interval 400 --log-format simple --print-training-sample-interval 1000 \ --num-workers 0 --max-tokens 24000 --max-sentences 32 \ --valid-subset $valid_subset --max-sentences-valid 64 \ - --distributed-world-size $ngpus --distributed-rank 0 --distributed-port 100 \ + --distributed-world-size $ngpus --distributed-port 100 --ddp-backend no_c10d \ --max-epoch 30 --optimizer adam --lr 0.001 --weight-decay 0.0 \ --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --min-lr 1e-5 --start-reduce-lr-epoch 11 \ --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates 400 \ --keep-interval-updates 5 --keep-last-epochs 5 --validate-interval 1 \ --arch speech_conv_lstm_wsj --criterion label_smoothed_cross_entropy_with_wer \ --label-smoothing 0.05 --smoothing-type temporal \ - --scheduled-sampling-probs 0.4 --start-scheduled-sampling-epoch 11 \ + --scheduled-sampling-probs 0.5 --start-scheduled-sampling-epoch 6 \ --train-feat-files $train_feat --train-text-files $train_token_text \ --valid-feat-files $valid_feat --valid-text-files $valid_token_text \ --dict $dict --non-lang-syms $nlsyms \ @@ -297,7 +297,7 @@ if [ ${stage} -le 9 ]; then decode_affix=shallow_fusion else path="$path:$wordlmdir/$lm_checkpoint" - opts="$opts --word-dict $wordlmdict --lm-weight 0.8 --oov-penalty 5e-7 --coverage-weight 0.01" + opts="$opts --word-dict $wordlmdict --lm-weight 0.8 --oov-penalty 1e-8 --coverage-weight 0.01" 
decode_affix=shallow_fusion_wordlm fi fi diff --git a/fairseq/criterions/cross_entropy_with_wer.py b/fairseq/criterions/cross_entropy_with_wer.py index 399bad979..26c78e5d9 100644 --- a/fairseq/criterions/cross_entropy_with_wer.py +++ b/fairseq/criterions/cross_entropy_with_wer.py @@ -159,7 +159,7 @@ def forward(self, model, sample, reduce=True): # if it is a dummy batch (e.g., a "padding" batch in a sharded # dataset), id might exceeds the dataset size; in that case we # just skip it - if id < len( self.valid_tgt_dataset): + if id < len(self.valid_tgt_dataset): ref_tokens = self.valid_tgt_dataset.get_original_tokens(id) pred_tokens = dict.string(pred.data[i]) self.scorer.add_evaluation(utt_id, ref_tokens, diff --git a/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py b/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py index 5ea5fb802..8e00f0bce 100644 --- a/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py +++ b/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py @@ -167,7 +167,7 @@ def forward(self, model, sample, reduce=True): # if it is a dummy batch (e.g., a "padding" batch in a sharded # dataset), id might exceeds the dataset size; in that case we # just skip it - if id < len( self.valid_tgt_dataset): + if id < len(self.valid_tgt_dataset): ref_tokens = self.valid_tgt_dataset.get_original_tokens(id) pred_tokens = dict.string(pred.data[i]) self.scorer.add_evaluation(utt_id, ref_tokens, diff --git a/fairseq/data/token_dictionary.py b/fairseq/data/token_dictionary.py index 1060faee2..934c0c1c6 100644 --- a/fairseq/data/token_dictionary.py +++ b/fairseq/data/token_dictionary.py @@ -34,7 +34,7 @@ def string(self, tensor, bpe_symbol=None, escape_unk=False): We overwrite this since we would like to also ignore . """ if torch.is_tensor(tensor) and tensor.dim() == 2: - return '\n'.join(self.string(t) for t in tensor) + return '\n'.join(self.string(t, bpe_symbol, escape_unk) for t in tensor) def token_string(i): if i == self.unk(): diff --git a/fairseq/models/speech_lstm.py b/fairseq/models/speech_lstm.py index b2f9a2bf0..4cf1b1d58 100644 --- a/fairseq/models/speech_lstm.py +++ b/fairseq/models/speech_lstm.py @@ -384,7 +384,7 @@ def forward(self, src, src_lengths): class SpeechLSTMEncoder(FairseqEncoder): """LSTM encoder.""" def __init__( - self, conv_layers_before=None, input_size=80, hidden_size=512, + self, conv_layers_before=None, input_size=83, hidden_size=512, num_layers=1, dropout_in=0.1, dropout_out=0.1, bidirectional=False, residual=False, left_pad=False, pretrained_embed=None, padding_value=0., ): @@ -531,7 +531,7 @@ def __init__( self.additional_fc = Linear(hidden_size + encoder_output_units, out_embed_dim) if adaptive_softmax_cutoff is not None: # setting adaptive_softmax dropout to dropout_out for now but can be redefined - self.adaptive_softmax = AdaptiveSoftmax(num_embeddings, embed_dim, adaptive_softmax_cutoff, + self.adaptive_softmax = AdaptiveSoftmax(num_embeddings, hidden_size, adaptive_softmax_cutoff, dropout=dropout_out) elif not self.share_input_output_embed: self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out) @@ -767,21 +767,21 @@ def lstm_lm_wsj(args): @register_model_architecture('lstm_lm', 'lstm_lm_librispeech') def lstm_lm_librispeech(args): args.dropout = getattr(args, 'dropout', 0.0) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 360) - args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 720) + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 800) + 
args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 800) args.decoder_layers = getattr(args, 'decoder_layers', 4) - args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 360) + args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 800) args.share_embed = getattr(args, 'share_embed', True) base_lm_architecture(args) @register_model_architecture('lstm_lm', 'lstm_wordlm_wsj') def lstm_wordlm_wsj(args): - args.dropout = getattr(args, 'dropout', 0.3) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024) - args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 1024) - args.decoder_layers = getattr(args, 'decoder_layers', 1) - args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 1024) + args.dropout = getattr(args, 'dropout', 0.35) + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1200) + args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 1200) + args.decoder_layers = getattr(args, 'decoder_layers', 3) + args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 1200) args.share_embed = getattr(args, 'share_embed', True) args.is_wordlm = True base_lm_architecture(args) diff --git a/fairseq/wer.py b/fairseq/wer.py index e3a89a1d2..c98bc30d3 100644 --- a/fairseq/wer.py +++ b/fairseq/wer.py @@ -35,11 +35,11 @@ def parse_wer_output_filter(self, wer_output_filter): if line.startswith('#!') or line == '': continue elif line.startswith('s/'): - m = re.match(r's/(\S+)/(\w*)/g', line) + m = re.match(r's/(.+)/(.*)/g', line) assert m is not None self.word_filters.append([m.group(1), m.group(2)]) elif line.startswith('s:'): - m = re.match(r's:(\S+):(\w*):g', line) + m = re.match(r's:(.+):(.*):g', line) assert m is not None self.word_filters.append([m.group(1), m.group(2)]) else: diff --git a/speech_tools/utils.py b/speech_tools/utils.py index db42b6b22..01978d0f9 100644 --- a/speech_tools/utils.py +++ b/speech_tools/utils.py @@ -248,7 +248,7 @@ def aligned_print(ref, hyp, steps): counter = Counter(steps) wer = float(counter['sub'] + counter['ins'] + counter['del']) / len(ref) \ - * 100 + * 100 if len(ref) > 0 else 0. 
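For a quick sanity check of the guarded formula above: with a 4-word reference, one substitution and one insertion, the utterance-level WER is (1 + 1 + 0) / 4 * 100 = 50%, and an empty reference now yields 0 instead of raising a ZeroDivisionError. A minimal sketch with hypothetical alignment steps:

    from collections import Counter

    steps = ['corr', 'sub', 'ins', 'corr', 'corr']  # hypothetical output of the aligner
    ref = ['a', 'b', 'c', 'd']                      # 4-word reference
    counter = Counter(steps)
    wer = float(counter['sub'] + counter['ins'] + counter['del']) / len(ref) \
        * 100 if len(ref) > 0 else 0.
    # wer == 50.0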
out_str += 'WER: ' + '{:.2f}%'.format(wer) + '\n' out_str += '\n' From 9745aa79b4655c0f13d52a11d6ce4f20c084f7b9 Mon Sep 17 00:00:00 2001 From: Hang Lyu Date: Thu, 20 Jun 2019 21:16:34 -0400 Subject: [PATCH 025/119] swbd --- examples/asr_swbd/cmd.sh | 20 ++ examples/asr_swbd/conf/fbank.conf | 2 + examples/asr_swbd/conf/pitch.conf | 1 + examples/asr_swbd/local/MSU_single_letter.txt | 1 + examples/asr_swbd/local/dict.patch | 1 + examples/asr_swbd/local/eval2000_data_prep.sh | 1 + examples/asr_swbd/local/extend_segments.pl | 1 + examples/asr_swbd/local/fisher_map_words.pl | 1 + .../asr_swbd/local/format_acronyms_dict.py | 1 + examples/asr_swbd/local/map_acronyms_ctm.py | 1 + .../local/map_acronyms_transcripts.py | 1 + examples/asr_swbd/local/prepare_ctm.py | 80 +++++ examples/asr_swbd/local/rt03_data_prep.sh | 1 + examples/asr_swbd/local/score_sclite.sh | 103 ++++++ .../asr_swbd/local/swbd1_data_download.sh | 1 + examples/asr_swbd/local/swbd1_data_prep.sh | 1 + .../asr_swbd/local/swbd1_fix_speakerid.pl | 1 + examples/asr_swbd/local/swbd1_map_words.pl | 1 + examples/asr_swbd/local/swbd1_prepare_dict.sh | 1 + examples/asr_swbd/local/wer_output_filter | 47 +++ examples/asr_swbd/path.sh | 17 + examples/asr_swbd/run.sh | 318 ++++++++++++++++++ examples/asr_swbd/steps | 1 + examples/asr_swbd/utils | 1 + 24 files changed, 604 insertions(+) create mode 100644 examples/asr_swbd/cmd.sh create mode 100644 examples/asr_swbd/conf/fbank.conf create mode 100644 examples/asr_swbd/conf/pitch.conf create mode 120000 examples/asr_swbd/local/MSU_single_letter.txt create mode 120000 examples/asr_swbd/local/dict.patch create mode 120000 examples/asr_swbd/local/eval2000_data_prep.sh create mode 120000 examples/asr_swbd/local/extend_segments.pl create mode 120000 examples/asr_swbd/local/fisher_map_words.pl create mode 120000 examples/asr_swbd/local/format_acronyms_dict.py create mode 120000 examples/asr_swbd/local/map_acronyms_ctm.py create mode 120000 examples/asr_swbd/local/map_acronyms_transcripts.py create mode 100755 examples/asr_swbd/local/prepare_ctm.py create mode 120000 examples/asr_swbd/local/rt03_data_prep.sh create mode 100755 examples/asr_swbd/local/score_sclite.sh create mode 120000 examples/asr_swbd/local/swbd1_data_download.sh create mode 120000 examples/asr_swbd/local/swbd1_data_prep.sh create mode 120000 examples/asr_swbd/local/swbd1_fix_speakerid.pl create mode 120000 examples/asr_swbd/local/swbd1_map_words.pl create mode 120000 examples/asr_swbd/local/swbd1_prepare_dict.sh create mode 100755 examples/asr_swbd/local/wer_output_filter create mode 100644 examples/asr_swbd/path.sh create mode 100644 examples/asr_swbd/run.sh create mode 120000 examples/asr_swbd/steps create mode 120000 examples/asr_swbd/utils diff --git a/examples/asr_swbd/cmd.sh b/examples/asr_swbd/cmd.sh new file mode 100644 index 000000000..b14280b96 --- /dev/null +++ b/examples/asr_swbd/cmd.sh @@ -0,0 +1,20 @@ +# you can change cmd.sh depending on what type of queue you are using. +# If you have no queueing system and want to run on a local machine, you +# can change all instances 'queue.pl' to run.pl (but be careful and run +# commands one by one: most recipes will exhaust the memory on your +# machine). queue.pl works with GridEngine (qsub). slurm.pl works +# with slurm. Different queues are configured differently, with different +# queue names and different ways of specifying things like memory; +# to account for these differences you can create and edit the file +# conf/queue.conf to match your queue's configuration. 
Search for +# conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information, +# or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl. + +#export train_cmd="run.pl --mem 4G" +#export cuda_cmd="run.pl --mem 4G --gpu 1" +#export decode_cmd="run.pl --mem 4G" + +# JHU setup +export train_cmd="queue.pl --mem 4G" +export cuda_cmd="queue.pl --mem 4G --gpu 1 --config conf/gpu.conf" +export decode_cmd="queue.pl --mem 4G" diff --git a/examples/asr_swbd/conf/fbank.conf b/examples/asr_swbd/conf/fbank.conf new file mode 100644 index 000000000..752323586 --- /dev/null +++ b/examples/asr_swbd/conf/fbank.conf @@ -0,0 +1,2 @@ +--sample-frequency=16000 +--num-mel-bins=80 diff --git a/examples/asr_swbd/conf/pitch.conf b/examples/asr_swbd/conf/pitch.conf new file mode 100644 index 000000000..e959a19d5 --- /dev/null +++ b/examples/asr_swbd/conf/pitch.conf @@ -0,0 +1 @@ +--sample-frequency=16000 diff --git a/examples/asr_swbd/local/MSU_single_letter.txt b/examples/asr_swbd/local/MSU_single_letter.txt new file mode 120000 index 000000000..9b034a146 --- /dev/null +++ b/examples/asr_swbd/local/MSU_single_letter.txt @@ -0,0 +1 @@ +../../../speech_tools/kaldi/egs/swbd/s5c/local/MSU_single_letter.txt \ No newline at end of file diff --git a/examples/asr_swbd/local/dict.patch b/examples/asr_swbd/local/dict.patch new file mode 120000 index 000000000..e2ead1dcf --- /dev/null +++ b/examples/asr_swbd/local/dict.patch @@ -0,0 +1 @@ +../../../speech_tools/kaldi/egs/swbd/s5c/local/dict.patch \ No newline at end of file diff --git a/examples/asr_swbd/local/eval2000_data_prep.sh b/examples/asr_swbd/local/eval2000_data_prep.sh new file mode 120000 index 000000000..179705396 --- /dev/null +++ b/examples/asr_swbd/local/eval2000_data_prep.sh @@ -0,0 +1 @@ +../../../speech_tools/kaldi/egs/swbd/s5c/local/eval2000_data_prep.sh \ No newline at end of file diff --git a/examples/asr_swbd/local/extend_segments.pl b/examples/asr_swbd/local/extend_segments.pl new file mode 120000 index 000000000..0ff7e3a1a --- /dev/null +++ b/examples/asr_swbd/local/extend_segments.pl @@ -0,0 +1 @@ +../../../speech_tools/kaldi/egs/swbd/s5c/local/extend_segments.pl \ No newline at end of file diff --git a/examples/asr_swbd/local/fisher_map_words.pl b/examples/asr_swbd/local/fisher_map_words.pl new file mode 120000 index 000000000..0b8445fc0 --- /dev/null +++ b/examples/asr_swbd/local/fisher_map_words.pl @@ -0,0 +1 @@ +../../../speech_tools/kaldi/egs/swbd/s5c/local/fisher_map_words.pl \ No newline at end of file diff --git a/examples/asr_swbd/local/format_acronyms_dict.py b/examples/asr_swbd/local/format_acronyms_dict.py new file mode 120000 index 000000000..c88fb9578 --- /dev/null +++ b/examples/asr_swbd/local/format_acronyms_dict.py @@ -0,0 +1 @@ +../../../speech_tools/kaldi/egs/swbd/s5c/local/format_acronyms_dict.py \ No newline at end of file diff --git a/examples/asr_swbd/local/map_acronyms_ctm.py b/examples/asr_swbd/local/map_acronyms_ctm.py new file mode 120000 index 000000000..47c775d1c --- /dev/null +++ b/examples/asr_swbd/local/map_acronyms_ctm.py @@ -0,0 +1 @@ +../../../speech_tools/kaldi/egs/swbd/s5c/local/map_acronyms_ctm.py \ No newline at end of file diff --git a/examples/asr_swbd/local/map_acronyms_transcripts.py b/examples/asr_swbd/local/map_acronyms_transcripts.py new file mode 120000 index 000000000..9d1b9c8b7 --- /dev/null +++ b/examples/asr_swbd/local/map_acronyms_transcripts.py @@ -0,0 +1 @@ +../../../speech_tools/kaldi/egs/swbd/s5c/local/map_acronyms_transcripts.py \ No newline at end of file diff 
--git a/examples/asr_swbd/local/prepare_ctm.py b/examples/asr_swbd/local/prepare_ctm.py new file mode 100755 index 000000000..705e31ee8 --- /dev/null +++ b/examples/asr_swbd/local/prepare_ctm.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python3 + +# CopyRight (c) 2019-present, Hang Lyu +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +# This script is use to transform the word level results to ctm format +# The start_time and end_time of each word is fake +# The format of input is like "en_4156-A_030185-030248 oh yeah" +# The format of output is like "en_4156 A start_time duration oh", and so on + +import argparse +import math +import re +import sys + +def main(): + args = get_args() + convert(args) + + +def get_args(): + parser = argparse.ArgumentParser( + description="""Transform the word level results to ctm format""") + parser.add_argument('ori_result', type=str, + help="The input--word level results.") + parser.add_argument('ctm_result', type=str, + help="The output--ctm format result.") + print(' '.join(sys.argv)) + print(sys.argv) + + args = parser.parse_args() + return args + + +def convert(args): + # read in word level results + with open(args.ori_result, 'r', encoding="utf-8") as f: + content = f.readlines() + # convert each line + split_content = [] # store ctm results + for i, line in enumerate(content): + elements = line.strip().split(' ') + + # The first field contains the information of the utterance + utt_info = elements[0] + infos = re.split("[-_]", utt_info) + utt_id = infos[0] + "_" + infos[1] + channel = infos[2] + start_time = round((int(infos[3])/100.0), 2) + end_time = round((int(infos[4])/100.0), 2) + + # generate ctm format results for each word + time_diff = int(infos[4]) - int(infos[3]) + time_step = round((float(time_diff) / (len(elements) - 1) / 100), 2) + for j, word in enumerate(elements): + start_time_tmp = start_time + time_step * (j - 1) + duration = 0.0 + if j == 0: + continue + elif j == len(elements) - 1: + duration = end_time - start_time_tmp + split_content.append(" ".join([utt_id, channel, + str(round(start_time_tmp,2)), + str(round(duration,2)), word])) + else: + duration = time_step + split_content.append(" ".join([utt_id, channel, + str(round(start_time_tmp,2)), + str(round(duration,2)), word])) + # print + with open(args.ctm_result, 'w', encoding='utf-8') as f: + for line in split_content: + print(line, file=f) + + +if __name__ == "__main__": + main() diff --git a/examples/asr_swbd/local/rt03_data_prep.sh b/examples/asr_swbd/local/rt03_data_prep.sh new file mode 120000 index 000000000..35e8bb102 --- /dev/null +++ b/examples/asr_swbd/local/rt03_data_prep.sh @@ -0,0 +1 @@ +../../../speech_tools/kaldi/egs/swbd/s5c/local/rt03_data_prep.sh \ No newline at end of file diff --git a/examples/asr_swbd/local/score_sclite.sh b/examples/asr_swbd/local/score_sclite.sh new file mode 100755 index 000000000..f965594ef --- /dev/null +++ b/examples/asr_swbd/local/score_sclite.sh @@ -0,0 +1,103 @@ +#!/bin/bash +# Copyright 2019-present Hang Lyu + +# begin configuration section. +cmd=run.pl +stage=0 +#end configuration section. + +[ -f ./path.sh ] && . ./path.sh +. parse_options.sh || exit 1; + +if [ $# -ne 2 ]; then + echo "Usage: local/score_sclite.sh [--cmd (run.pl|queue.pl...)] " + echo " Options:" + echo " --cmd (run.pl|queue.pl...) # specify how to run the sub-processes." 
+ echo " --stage (0|1|2|3) # start scoring script from part-way through." + exit 1; +fi + +data=$1 +dir=$2 + +hubscr=$KALDI_ROOT/tools/sctk/bin/hubscr.pl +[ ! -f $hubscr ] && echo "Cannot find scoring program at $hubscr" && exit 1; +hubdir=`dirname $hubscr` + +for f in $data/stm $data/glm $dir/decoded_results.txt; do + [ ! -f $f ] && echo "$0: expecting file $f to exist" && exit 1; +done + +name=`basename $data`; # e.g. eval2000 + +mkdir -p $dir/scoring/log + +if [ $stage -le 0 ]; then + # prepare the $name.ctm files for test set + python3 local/prepare_ctm.py $dir/decoded_results.txt $dir/scoring/$name.ctm +fi + +if [ $stage -le 1 ]; then + # Remove some stuff we don't want to score, from the ctm. + # the big expression in parentheses contains all the things that get mapped + # by the glm file, into hesitations. + # The -$ expression removes partial words. + # the aim here is to remove all the things that appear in the reference as optionally + # deletable (inside parentheses), as if we delete these there is no loss, while + # if we get them correct there is no gain. + for x in $dir/scoring/$name.ctm; do + cp $x $dir/scoring/tmpf; + cat $dir/scoring/tmpf | grep -i -v -E '\[NOISE|LAUGHTER|VOCALIZED-NOISE\]' | \ + grep -i -v -E '' | \ + grep -i -v -E ' (UH|UM|EH|MM|HM|AH|HUH|HA|ER|OOF|HEE|ACH|EEE|EW)$' | \ + grep -v -- '-$' > $x; + python local/map_acronyms_ctm.py -i $x -o $x.mapped -M data/local/dict_nosp/acronyms.map + cp $x $x.bk + mv $x.mapped $x + done +fi + +# Score the set... +if [ $stage -le 2 ]; then + $cmd $dir/scoring/log/score.log \ + cp $data/stm $dir/scoring/ '&&' \ + $hubscr -p $hubdir -V -l english -h hub5 -g $data/glm -r $dir/scoring/stm $dir/scoring/${name}.ctm || exit 1; +fi + +# For eval2000 score the subsets +case "$name" in + eval2000*) + # Score only the, swbd part... + if [ $stage -le 3 ]; then + $cmd $dir/scoring/log/score.swbd.log \ + grep -v '^en_' $data/stm '>' $dir/scoring/stm.swbd '&&' \ + grep -v '^en_' $dir/scoring/${name}.ctm '>' $dir/scoring/${name}.ctm.swbd '&&' \ + $hubscr -p $hubdir -V -l english -h hub5 -g $data/glm -r $dir/scoring/stm.swbd $dir/scoring/${name}.ctm.swbd || exit 1; + fi + # Score only the, callhome part... + if [ $stage -le 3 ]; then + $cmd $dir/scoring/log/score.callhm.log \ + grep -v '^sw_' $data/stm '>' $dir/scoring/stm.callhm '&&' \ + grep -v '^sw_' $dir/scoring/${name}.ctm '>' $dir/scoring/${name}.ctm.callhm '&&' \ + $hubscr -p $hubdir -V -l english -h hub5 -g $data/glm -r $dir/scoring/stm.callhm $dir/scoring/${name}.ctm.callhm || exit 1; + fi + ;; +rt03* ) + # Score only the swbd part... + if [ $stage -le 3 ]; then + $cmd $dir/scoring/log/score.swbd.log \ + grep -v '^fsh_' $data/stm '>' $dir/scoring/stm.swbd '&&' \ + grep -v '^fsh_' $dir/scoring/${name}.ctm '>' $dir/scoring/${name}.ctm.swbd '&&' \ + $hubscr -p $hubdir -V -l english -h hub5 -g $data/glm -r $dir/scoring/stm.swbd $dir/scoring/${name}.ctm.swbd || exit 1; + fi + # Score only the fisher part... 
+ if [ $stage -le 3 ]; then + $cmd $dir/scoring/log/score.fsh.log \ + grep -v '^sw_' $data/stm '>' $dir/scoring/stm.fsh '&&' \ + grep -v '^sw_' $dir/scoring/${name}.ctm '>' $dir/scoring/${name}.ctm.fsh '&&' \ + $hubscr -p $hubdir -V -l english -h hub5 -g $data/glm -r $dir/scoring/stm.fsh $dir/scoring/${name}.ctm.fsh || exit 1; + fi + ;; +esac + +exit 0 diff --git a/examples/asr_swbd/local/swbd1_data_download.sh b/examples/asr_swbd/local/swbd1_data_download.sh new file mode 120000 index 000000000..676f5e0b4 --- /dev/null +++ b/examples/asr_swbd/local/swbd1_data_download.sh @@ -0,0 +1 @@ +../../../speech_tools/kaldi/egs/swbd/s5c/local/swbd1_data_download.sh \ No newline at end of file diff --git a/examples/asr_swbd/local/swbd1_data_prep.sh b/examples/asr_swbd/local/swbd1_data_prep.sh new file mode 120000 index 000000000..7faee28eb --- /dev/null +++ b/examples/asr_swbd/local/swbd1_data_prep.sh @@ -0,0 +1 @@ +../../../speech_tools/kaldi/egs/swbd/s5c/local/swbd1_data_prep.sh \ No newline at end of file diff --git a/examples/asr_swbd/local/swbd1_fix_speakerid.pl b/examples/asr_swbd/local/swbd1_fix_speakerid.pl new file mode 120000 index 000000000..83a348533 --- /dev/null +++ b/examples/asr_swbd/local/swbd1_fix_speakerid.pl @@ -0,0 +1 @@ +../../../speech_tools/kaldi/egs/swbd/s5c/local/swbd1_fix_speakerid.pl \ No newline at end of file diff --git a/examples/asr_swbd/local/swbd1_map_words.pl b/examples/asr_swbd/local/swbd1_map_words.pl new file mode 120000 index 000000000..f35ddcb7f --- /dev/null +++ b/examples/asr_swbd/local/swbd1_map_words.pl @@ -0,0 +1 @@ +../../../speech_tools/kaldi/egs/swbd/s5c/local/swbd1_map_words.pl \ No newline at end of file diff --git a/examples/asr_swbd/local/swbd1_prepare_dict.sh b/examples/asr_swbd/local/swbd1_prepare_dict.sh new file mode 120000 index 000000000..2b5a643c7 --- /dev/null +++ b/examples/asr_swbd/local/swbd1_prepare_dict.sh @@ -0,0 +1 @@ +../../../speech_tools/kaldi/egs/swbd/s5c/local/swbd1_prepare_dict.sh \ No newline at end of file diff --git a/examples/asr_swbd/local/wer_output_filter b/examples/asr_swbd/local/wer_output_filter new file mode 100755 index 000000000..d4ea2c52d --- /dev/null +++ b/examples/asr_swbd/local/wer_output_filter @@ -0,0 +1,47 @@ +#!/bin/sed -f +s:\[noise\]::g +s:\[laughter\]::g +s:\[vocalized-noise\]::g +s:^uh ::g +s: uh$::g +s: uh : :g +s:^um ::g +s: um$::g +s: um : :g +s:^eh ::g +s: eh$::g +s: eh : :g +s:^mm ::g +s: mm$::g +s: mm : :g +s:^hm ::g +s: hm$::g +s: hm : :g +s:^ah ::g +s: ah$::g +s: ah : :g +s:^huh ::g +s: huh$::g +s: huh : :g +s:^ha ::g +s: ha$::g +s: ha : :g +s:^er ::g +s: er$::g +s: er : :g +s:^oof ::g +s: oof$::g +s: oof : :g +s:^hee ::g +s: hee$::g +s: hee : :g +s:^ach ::g +s: ach$::g +s: ach : :g +s:^eee ::g +s: eee$::g +s: eee : :g +s:^ew ::g +s: ew$::g +s: ew : :g + diff --git a/examples/asr_swbd/path.sh b/examples/asr_swbd/path.sh new file mode 100644 index 000000000..3290c7576 --- /dev/null +++ b/examples/asr_swbd/path.sh @@ -0,0 +1,17 @@ +MAIN_ROOT=$PWD/../.. +KALDI_ROOT=$MAIN_ROOT/speech_tools/kaldi + +# BEGIN from kaldi path.sh +[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh +export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/tools/sctk/bin:$PWD:$PATH +[ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 +. 
$KALDI_ROOT/tools/config/common_path.sh +export LC_ALL=C +# END + +export PATH=~/anaconda3/bin:$PATH +export PATH=$MAIN_ROOT:$MAIN_ROOT/speech_tools:$PATH +export PATH=$MAIN_ROOT/speech_tools/sentencepiece/build/src:$PATH +export PYTHONPATH=$MAIN_ROOT:$MAIN_ROOT/speech_tools:$PYTHONPATH +export PYTHONUNBUFFERED=1 + diff --git a/examples/asr_swbd/run.sh b/examples/asr_swbd/run.sh new file mode 100644 index 000000000..24f56f926 --- /dev/null +++ b/examples/asr_swbd/run.sh @@ -0,0 +1,318 @@ +#!/bin/bash + +# Copyright (c) 2019-present, Hang Lyu +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +set -e -o pipefail + +stage=0 +ngpus=2 # num GPUs for multiple GPUs training within a single node; should match those in $free_gpu +free_gpu= # comma-separated available GPU ids, eg., "0" or "0,1"; automatically assigned if on CLSP grid + +# E2E model related +affix= +train_set=train_nodup +valid_set=train_dev +test_sets="train_dev eval2000 rt03" +checkpoint=checkpoint_best.pt + +validate_on_train_subset=false # for monitoring E2E model training + +# LM related +lm_affix= +lm_checkpoint=checkpoint_best.pt +lm_shallow_fusion=true # no LM fusion if false +sentencepiece_vocabsize=1000 +sentencepiece_type=unigram + +# data related +dumpdir=data/dump # directory to dump full features +swbd1_dir= +eval2000_dir= +rt03_dir= +fisher_dirs= + +if [[ $(hostname -f) == *.clsp.jhu.edu ]]; then + swbd1_dir=/export/corpora3/LDC/LDC97S62 + eval2000_dir="/export/corpora2/LDC/LDC2002S09/hub5e_00 /export/corpora2/LDC/LDC2002T43" + rt03_dir=/export/corpora/LDC/LDC2007S10 + fisher_dirs="/export/corpora3/LDC/LDC2004T19/fe_03_p1_tran/ /export/corpora3/LDC/LDC2005T19/fe_03_p2_tran/" +fi +train_subset_size=500 # for validation if validate_on_train_subset is set to true +kaldi_scoring=true + +# feature configuration +do_delta=false + + +. ./path.sh +. ./cmd.sh +. ./utils/parse_options.sh + +lmdir=exp/lm_lstm${lm_affix:+_${lm_affix}} +wordlmdir=exp/wordlm_lstm${wordlm_affix:+_${wordlm_affix}} +dir=exp/lstm${affix:+_$affix} + +if [ $stage -le 0 ]; then + echo "Stage 0: Data Preparation" + local/swbd1_data_download.sh ${swbd1_dir} + local/swbd1_prepare_dict.sh + local/swbd1_data_prep.sh ${swbd1_dir} + local/eval2000_data_prep.sh ${eval2000_dir} + local/rt03_data_prep.sh ${rt03_dir} + # upsample audio from 8k to 16k to compare with espnet. (It may affact the + # performance + for x in train eval2000 rt03; do + sed -i.bak -e "s/$/ sox -R -t wav - -t wav - rate 16000 dither | /" data/${x}/wav.scp + done + # normalize eval2000 ant rt03 texts by + # 1) convert upper to lower + # 2) remove tags (%AH) (%HESITATION) (%UH) + # 3) remove + # 4) remove "(" or ")" + for x in eval2000 rt03; do + cp data/${x}/text data/${x}/text.org + paste -d "" \ + <(cut -f 1 -d" " data/${x}/text.org) \ + <(awk '{$1=""; print tolower($0)}' data/${x}/text.org | perl -pe 's| \(\%.*\)||g' | perl -pe 's| \<.*\>||g' | sed -e "s/(//g" -e "s/)//g") \ + | sed -e 's/\s\+/ /g' > data/${x}/text + # rm data/${x}/text.org + done + echo "Succeeded in formatting data." 
+fi + + +train_feat_dir=$dumpdir/$train_set/delta${do_delta}; mkdir -p $train_feat_dir +valid_feat_dir=$dumpdir/$valid_set/delta${do_delta}; mkdir -p $valid_feat_dir +if [ $stage -le 1 ]; then + echo "Stage 1: Feature Generation" + fbankdir=fbank + # Generate the fbank features; by default 80-dimensional fbanks with pitch on + # each frame + for x in train eval2000 rt03; do + steps/make_fbank_pitch.sh --cmd "$train_cmd" --nj 32 --write-utt2num-frames true \ + data/$x exp/make_fbank/$x $fbankdir + utils/fix_data_dir.sh data/$x + done + + utils/subset_data_dir.sh --first data/train 4000 data/train_dev # 5hr 6min + n=$[`cat data/train/segments | wc -l` - 4000] + utils/subset_data_dir.sh --last data/train $n data/train_nodev + utils/data/remove_dup_utts.sh 300 data/train_nodev data/train_nodup # 286hr + + # compute global CMVN + compute-cmvn-stats scp:data/$train_set/feats.scp data/$train_set/cmvn.ark + + # dump features for training + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $train_feat_dir/storage ]; then + utils/create_split_dir.pl \ + /export/b{14,15,16,17}/$USER/espnet-data/egs/swbd/asr1/dump/$train_set/delta${do_delta}/storage \ + $train_feat_dir/storage + fi + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $valid_feat_dir/storage ]; then + utils/create_split_dir.pl \ + /export/b{14,15,16,17}/$USER/espnet-data/egs/swbd/asr1/dump/$valid_set/delta${do_delta}/storage \ + $valid_feat_dir/storage + fi + dump.sh --cmd "$train_cmd" --nj 32 --do_delta $do_delta \ + data/$train_set/feats.scp data/$train_set/cmvn.ark exp/dump_feats/train $train_feat_dir + dump.sh --cmd "$train_cmd" --nj 10 --do_delta $do_delta \ + data/$valid_set/feats.scp data/$train_set/cmvn.ark exp/dump_feats/dev $valid_feat_dir + for rtask in $test_sets; do + test_feat_dir=$dumpdir/$rtask/delta${do_delta}; mkdir -p $test_feat_dir + dump.sh --cmd "$train_cmd" --nj 10 --do_delta $do_delta \ + data/$rtask/feats.scp data/$train_set/cmvn.ark exp/dump_feats/recog/$rtask \ + $test_feat_dir + done + echo "Succeeded in generating features for train_nodup, train_dev, eval2000 and rt03" +fi + + +dict=data/lang/${train_set}_${sentencepiece_type}${sentencepiece_vocabsize}_units.txt +sentencepiece_model=data/lang/${train_set}_${sentencepiece_type}${sentencepiece_vocabsize} +nlsyms=data/lang/non_lang_syms.txt +lmdatadir=data/lm_text +if [ $stage -le 2 ]; then + echo "Stage 2: Dictionary Preparation and Text Tokenization" + mkdir -p data/lang + mkdir -p $lmdatadir + + echo "Making a non-linguistic symbol list..." + train_text=data/$train_set/text + cut -f 2- $train_text | tr " " "\n" | sort | uniq | grep "\[" > $nlsyms + cat $nlsyms + + echo "Preparing extra corpus for subword LM training..." + if [ -f $lmdatadir/fisher_text0 ]; then + rm -rf $lmdatadir/fisher_text0 + fi + for x in $fisher_dirs; do + [ ! -d $x/data/trans ] \ + && "Cannot find transcripts in Fisher directory $x" && exit 1; + cat $x/data/trans/*/*.txt | \ + grep -v ^# | grep -v ^$ | cut -d' ' -f4- >> $lmdatadir/fisher_text0 + done + cat $lmdatadir/fisher_text0 | local/fisher_map_words.pl | \ + sed 's/^[ \t]*//'> $lmdatadir/fisher_text + + echo "Training sentencepiece model..." 
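The commands below concatenate the training transcripts with the normalized Fisher text and train a unigram sentencepiece model of $sentencepiece_vocabsize pieces (plus the three reserved pad/eos/unk ids), keeping the bracketed non-linguistic symbols as user-defined pieces. Purely as a hypothetical illustration of the piece-level encoding later produced by ``spm_encode --output_format=piece``:

    # hypothetical transcript (utterance id already stripped):
    it was a beautiful day
    # one possible segmentation into pieces (the actual split depends on the learned model):
    ▁it ▁was ▁a ▁beau ti ful ▁day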
+ cut -f 2- -d" " data/$train_set/text | \ + cat - $lmdatadir/fisher_text > data/lang/input + spm_train --bos_id=-1 --pad_id=0 --eos_id=1 --unk_id=2 --input=data/lang/input \ + --vocab_size=$((sentencepiece_vocabsize+3)) --character_coverage=1.0 \ + --model_type=$sentencepiece_type --model_prefix=$sentencepiece_model \ + --input_sentence_size=10000000 \ + --user_defined_symbols=$(cut -f 2- $train_text | tr " " "\n" | sort | uniq | grep "\[" | tr "\n" "," | sed 's/,$//') + + echo "Making a dictionary and tokenizing text for train/valid/test sets..." + for dataset in $train_set $test_sets; do # validation is included in tests + text=data/$dataset/text + token_text=data/$dataset/token_text + spm_encode --model=${sentencepiece_model}.model --output_format=piece \ + <(cut -f 2- -d' ' $text) | paste -d" " <(cut -f 1 -d' ' $text) - > $token_text + # prepare dict with train_set + if [ "$dataset" == "$train_set" ]; then + cut -f 2- -d" " $token_text | tr " " "\n" | grep -v -e '^\s*$' | sort | \ + uniq -c | awk '{print $2,$1}' > $dict + wc -l $dict + fi + done + + echo "Preparing text for subword LM..." + mkdir -p $lmdatadir + for dataset in $train_set $test_sets; do + token_text=data/$dataset/token_text + cut -f 2- -d" " $token_text > $lmdatadir/$dataset.tokens + done + + echo "Preparing extra corpus for subword LM training..." + cat $lmdatadir/fisher_text |\ + spm_encode --model=${sentencepiece_model}.model --output_format=piece |\ + cat $lmdatadir/$train_set.tokens - > $lmdatadir/train.tokens +fi + + +lmdict=$dict +if [ $stage -le 3 ]; then + echo "Stage 3: Text Binarization for subword LM Training" + mkdir -p $lmdatadir/logs + for dataset in $test_sets; do test_paths="$test_paths $lmdatadir/$dataset.tokens"; done + test_paths=$(echo $test_paths | awk '{$1=$1;print}' | tr ' ' ',') + ${decode_cmd} $lmdatadir/logs/preprocess.log \ + python3 ../../preprocess.py --task language_modeling_for_asr \ + --workers 50 --srcdict $lmdict --only-source \ + --trainpref $lmdatadir/train.tokens \ + --validpref $lmdatadir/$valid_set.tokens \ + --testpref $test_paths \ + --destdir $lmdatadir +fi + + +[ -z "$free_gpu" ] && [[ $(hostname -f) == *.clsp.jhu.edu ]] && free_gpu=$(free-gpu -n $ngpus) || \ + echo "Unable to get $ngpus GPUs" +[ -z "$free_gpu" ] && echo "$0: please specify --free-gpu" && exit 1; +[ $(echo $free_gpu | sed 's/,/ /g' | awk '{print NF}') -ne "$ngpus" ] && \ + echo "number of GPU ids in --free-gpu=$free_gpu does not match --ngpus=$ngpus" && exit 1; + + +if [ $stage -le 4 ]; then + echo "Stage 4: subword LM Training" + valid_subset=valid + mkdir -p $lmdir/logs + log_file=$lmdir/logs/train.log + [ -f $lmdir/checkpoint_last.pt ] && log_file="-a $log_file" + CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../train.py $lmdatadir --seed 1 \ + --task language_modeling_for_asr --dict $lmdict \ + --log-interval 500 --log-format simple \ + --num-workers 0 --max-tokens 30720 --max-sentences 1024 \ + --valid-subset $valid_subset --max-sentences-valid 1536 \ + --distributed-world-size $ngpus --distributed-rank 0 --distributed-port 100 \ + --max-epoch 25 --optimizer adam --lr 0.001 --clip-norm 1.0 \ + --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \ + --save-dir $lmdir --restore-file checkpoint_last.pt --save-interval-updates 500 \ + --keep-interval-updates 3 --keep-last-epochs 5 --validate-interval 1 \ + --arch lstm_lm_librispeech --criterion cross_entropy --sample-break-mode eos 2>&1 | tee $log_file +fi + + +if [ $stage -le 5 ]; then + echo "Stage 5: subword LM Evaluation" + gen_set_array=(test) + num=$(echo 
$test_sets | awk '{print NF-1}') + for i in $(seq $num); do gen_set_array[$i]="test$i"; done #gen_set_array=(test test1 test2) + test_set_array=($test_sets) + for i in $(seq 0 $num); do + log_file=$lmdir/logs/evaluation_${test_set_array[$i]}.log + python3 ../../eval_lm.py $lmdatadir --cpu \ + --task language_modeling_for_asr --dict $lmdict --gen-subset ${gen_set_array[$i]} \ + --max-tokens 40960 --max-sentences 1536 --sample-break-mode eos \ + --path $lmdir/$lm_checkpoint 2>&1 | tee $log_file + done +fi + + +train_feat=$train_feat_dir/feats.scp +train_token_text=data/$train_set/token_text +valid_feat=$valid_feat_dir/feats.scp +valid_token_text=data/$valid_set/token_text +if [ $stage -le 6 ]; then + echo "Stage 6: Model Training" + valid_subset=valid + opts="" + [ -f local/wer_output_filter ] && opts="$opts --wer-output-filter local/wer_output_filter" + mkdir -p $dir/logs + log_file=$dir/logs/train.log + [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" + CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --seed 1 \ + --log-interval 1500 --log-format simple --print-training-sample-interval 2000 --ddp-backend "no_c10d" \ + --num-workers 0 --max-tokens 26000 --max-sentences 24 \ + --valid-subset $valid_subset --max-sentences-valid 48 \ + --distributed-world-size $ngpus --distributed-rank 0 --distributed-port 100 \ + --max-epoch 25 --optimizer adam --lr 0.001 --weight-decay 0.0 --clip-norm 2.0 \ + --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --min-lr 1e-5 --start-reduce-lr-epoch 10 \ + --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates 1500 \ + --keep-interval-updates 3 --keep-last-epochs 5 --validate-interval 1 \ + --arch speech_conv_lstm_librispeech --criterion label_smoothed_cross_entropy_with_wer \ + --label-smoothing 0.1 --smoothing-type uniform \ + --scheduled-sampling-probs 1.0 --start-scheduled-sampling-epoch 1 \ + --train-feat-files $train_feat --train-text-files $train_token_text \ + --valid-feat-files $valid_feat --valid-text-files $valid_token_text \ + --dict $dict --remove-bpe sentencepiece --non-lang-syms $nlsyms \ + --max-source-positions 9999 --max-target-positions 999 $opts 2>&1 | tee $log_file +fi + + +if [ $stage -le 7 ]; then + echo "Stage 7: Decoding" + opts="" + path=$dir/$checkpoint + decode_affix= + if $lm_shallow_fusion; then + path="$path:$lmdir/$lm_checkpoint" + opts="$opts --lm-weight 0.3 --coverage-weight 0.0" + decode_affix=shallow_fusion + fi + [ -f local/wer_output_filter ] && opts="$opts --wer-output-filter local/wer_output_filter" + for dataset in $test_sets; do + feat=${dumpdir}/$dataset/delta${do_delta}/feats.scp + text=data/$dataset/token_text + CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py \ + --max-tokens 16000 --max-sentences 24 --num-shards 1 --shard-id 0 \ + --test-feat-files $feat --test-text-files $text \ + --dict $dict --remove-bpe sentencepiece --non-lang-syms $nlsyms \ + --max-source-positions 9999 --max-target-positions 999 \ + --path $path --beam 30 --max-len-a 0.08 --max-len-b 0 --lenpen 1.0 \ + --results-path $dir/decode_$dataset${decode_affix:+_${decode_affix}} $opts \ + 2>&1 | tee $dir/logs/decode_$dataset${decode_affix:+_${decode_affix}}.log + done + if $kaldi_scoring; then + echo "verify WER by scoring with kaldi..." 
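Kaldi scoring goes through ``local/score_sclite.sh``, which first calls ``local/prepare_ctm.py`` (added earlier in this patch) to turn the word-level hypotheses into CTM entries whose per-word timestamps are spread evenly over each segment's frame range. A hypothetical example with made-up frame indices:

    # input line from decoded_results.txt:
    en_4156-A_030185-030285 oh yeah
    # resulting CTM lines (timestamps are approximate by construction):
    en_4156 A 301.85 0.5 oh
    en_4156 A 302.35 0.5 yeah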
+    # The word level results are stored in decode_dir/decoded_results.txt
+    local/score_sclite.sh data/eval2000 $dir/decode_eval2000${decode_affix:+_${decode_affix}}
+    local/score_sclite.sh data/rt03 $dir/decode_rt03${decode_affix:+_${decode_affix}}
+  fi
+fi
diff --git a/examples/asr_swbd/steps b/examples/asr_swbd/steps
new file mode 120000
index 000000000..ec9b528ac
--- /dev/null
+++ b/examples/asr_swbd/steps
@@ -0,0 +1 @@
+../../speech_tools/kaldi/egs/wsj/s5/steps
\ No newline at end of file
diff --git a/examples/asr_swbd/utils b/examples/asr_swbd/utils
new file mode 120000
index 000000000..ea44d93b9
--- /dev/null
+++ b/examples/asr_swbd/utils
@@ -0,0 +1 @@
+../../speech_tools/kaldi/egs/wsj/s5/utils
\ No newline at end of file

From 7ae022204bec2df5af714d3f19c52d34828ee48b Mon Sep 17 00:00:00 2001
From: freewym
Date: Wed, 26 Jun 2019 15:30:46 -0400
Subject: [PATCH 026/119] add coverage to librispeech recipe; make best metric for ASR configurable by using --best-checkpoint-metric; code adaptation/changes according to the commits from Jun 26, 2019 to Jul 3, 2019

---
 examples/asr_librispeech/run.sh | 8 +++---
 examples/asr_swbd/run.sh | 2 +-
 examples/asr_wsj/run.sh | 2 +-
 fairseq/data/token_dictionary.py | 13 +++++++++-
 speech_recognize.py | 4 +--
 speech_train.py | 42 +++++++++++++++++++-------------
 6 files changed, 45 insertions(+), 26 deletions(-)

diff --git a/examples/asr_librispeech/run.sh b/examples/asr_librispeech/run.sh
index 14ce148da..895027bac 100755
--- a/examples/asr_librispeech/run.sh
+++ b/examples/asr_librispeech/run.sh
@@ -215,7 +215,7 @@ if [ ${stage} -le 7 ]; then
     --max-epoch 25 --optimizer adam --lr 0.001 --weight-decay 0.0 --clip-norm 2.0 \
     --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --min-lr 1e-5 --start-reduce-lr-epoch 10 \
     --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates 3000 \
-    --keep-interval-updates 3 --keep-last-epochs 5 --validate-interval 1 \
+    --keep-interval-updates 3 --keep-last-epochs 5 --validate-interval 1 --best-checkpoint-metric wer \
     --arch speech_conv_lstm_librispeech --criterion label_smoothed_cross_entropy_with_wer \
     --label-smoothing 0.1 --smoothing-type uniform \
     --scheduled-sampling-probs 1.0 --start-scheduled-sampling-epoch 1 \
@@ -232,18 +232,18 @@ if [ ${stage} -le 8 ]; then
   decode_affix=
   if $lm_shallow_fusion; then
     path="$path:$lmdir/$lm_checkpoint"
-    opts="$opts --lm-weight 0.35 --coverage-weight 0.0"
+    opts="$opts --lm-weight 0.4 --coverage-weight 0.015"
     decode_affix=shallow_fusion
   fi
   for dataset in $test_set; do
     feat=${dumpdir}/$dataset/delta${do_delta}/feats.scp
     text=data/$dataset/token_text
     CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py \
-      --max-tokens 16000 --max-sentences 24 --num-shards 1 --shard-id 0 \
+      --max-tokens 15000 --max-sentences 24 --num-shards 1 --shard-id 0 \
       --test-feat-files $feat --test-text-files $text \
       --dict $dict --remove-bpe sentencepiece \
       --max-source-positions 9999 --max-target-positions 999 \
-      --path $path --beam 20 --max-len-a 0.08 --max-len-b 0 --lenpen 1.0 \
+      --path $path --beam 35 --max-len-a 0.08 --max-len-b 0 --lenpen 1.0 \
       --results-path $dir/decode_$dataset${decode_affix:+_${decode_affix}} $opts \
       2>&1 | tee $dir/logs/decode_$dataset${decode_affix:+_${decode_affix}}.log
diff --git a/examples/asr_swbd/run.sh b/examples/asr_swbd/run.sh
index 24f56f926..5b9a21c74 100644
--- a/examples/asr_swbd/run.sh
+++ b/examples/asr_swbd/run.sh
@@ -275,7 +275,7 @@ if [ $stage -le 6 ]; then
     --max-epoch 25 --optimizer adam --lr 0.001
--weight-decay 0.0 --clip-norm 2.0 \ --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --min-lr 1e-5 --start-reduce-lr-epoch 10 \ --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates 1500 \ - --keep-interval-updates 3 --keep-last-epochs 5 --validate-interval 1 \ + --keep-interval-updates 3 --keep-last-epochs 5 --validate-interval 1 --best-checkpoint-metric wer \ --arch speech_conv_lstm_librispeech --criterion label_smoothed_cross_entropy_with_wer \ --label-smoothing 0.1 --smoothing-type uniform \ --scheduled-sampling-probs 1.0 --start-scheduled-sampling-epoch 1 \ diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index 1d7c7a942..eedc5194d 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -275,7 +275,7 @@ if [ ${stage} -le 8 ]; then --max-epoch 30 --optimizer adam --lr 0.001 --weight-decay 0.0 \ --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --min-lr 1e-5 --start-reduce-lr-epoch 11 \ --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates 400 \ - --keep-interval-updates 5 --keep-last-epochs 5 --validate-interval 1 \ + --keep-interval-updates 5 --keep-last-epochs 5 --validate-interval 1 --best-checkpoint-metric wer \ --arch speech_conv_lstm_wsj --criterion label_smoothed_cross_entropy_with_wer \ --label-smoothing 0.05 --smoothing-type temporal \ --scheduled-sampling-probs 0.5 --start-scheduled-sampling-epoch 6 \ diff --git a/fairseq/data/token_dictionary.py b/fairseq/data/token_dictionary.py index 934c0c1c6..a91d89d09 100644 --- a/fairseq/data/token_dictionary.py +++ b/fairseq/data/token_dictionary.py @@ -14,7 +14,15 @@ class TokenDictionary(Dictionary): """A mapping from symbols to consecutive integers""" - def __init__(self, pad='', eos='', unk='', bos='', space=''): + def __init__( + self, + pad='', + eos='', + unk='', + bos='', + space='', + extra_special_symbols=None, + ): self.unk_word, self.pad_word, self.eos_word, self.bos_word, self.space_word = \ unk, pad, eos, bos, space self.symbols = [] @@ -23,6 +31,9 @@ def __init__(self, pad='', eos='', unk='', bos='', space=' args.min_lr or trainer.get_num_updates() <= getattr(args, 'warmup_updates', 0)) and \ - (epoch_itr.epoch < max_epoch or (epoch_itr.epoch == max_epoch and \ - epoch_itr._next_epoch_itr is not None)) and trainer.get_num_updates() < max_update: + while ( + (lr > args.min_lr or trainer.get_num_updates() <= getattr(args, 'warmup_updates', 0)) + and ( + epoch_itr.epoch < max_epoch or ( + epoch_itr.epoch == max_epoch + and epoch_itr._next_epoch_itr is not None + ) + ) + and trainer.get_num_updates() < max_update + ): # train for one epoch train(args, trainer, task, epoch_itr) if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0: - valid_losses, valid_wers = validate(args, trainer, task, epoch_itr, valid_subsets) + valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets) # only use first validation wer to update the learning rate - lr = trainer.lr_step(epoch_itr.epoch, valid_wers[0]) + lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0]) # save checkpoint if epoch_itr.epoch % args.save_interval == 0: - checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_wers[0]) + checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) if len(args.train_feat_files) > 1: # sharded data: get train iterator for next epoch @@ -156,8 +163,8 @@ def train(args, trainer, task, epoch_itr): and num_updates % args.save_interval_updates == 0 and num_updates > 0 ): - valid_losses, valid_wers = 
validate(args, trainer, task, epoch_itr, valid_subsets) - checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_wers[0]) + valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets) + checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) if num_updates >= max_update: break @@ -205,12 +212,11 @@ def get_training_stats(trainer): def validate(args, trainer, task, epoch_itr, subsets): """Evaluate the model on the validation set(s) and return the losses.""" valid_losses = [] - valid_wers = [] for subset in subsets: # Initialize data iterator itr = task.get_batch_iterator( dataset=task.dataset(subset), - max_tokens=args.max_tokens, + max_tokens=args.max_tokens_valid, max_sentences=args.max_sentences_valid, max_positions=utils.resolve_max_positions( task.max_positions(), @@ -260,14 +266,16 @@ def validate(args, trainer, task, epoch_itr, subsets): # log validation stats stats = get_valid_stats(trainer) for k, meter in extra_meters.items(): - stats[k] = meter.avg + stats[k] = meter if k == 'wer' or k == 'cer' else meter.avg if hasattr(checkpoint_utils.save_checkpoint, 'best'): - stats['best_wer'] = min(checkpoint_utils.save_checkpoint.best, stats['wer']) + stats['best_' + args.best_checkpoint_metric] = min( + checkpoint_utils.save_checkpoint.best, + stats[args.best_checkpoint_metric].avg, + ) progress.print(stats, tag=subset, step=trainer.get_num_updates()) - valid_losses.append(stats['loss'].avg) - valid_wers.append(stats['wer']) - return valid_losses, valid_wers + valid_losses.append(stats[args.best_checkpoint_metric].avg) + return valid_losses def get_valid_stats(trainer): From 6ab3bd5492299e5e3ef23f644d777828f1f8d4d2 Mon Sep 17 00:00:00 2001 From: freewym Date: Thu, 4 Jul 2019 12:29:43 -0400 Subject: [PATCH 027/119] fix swbd recipe; code adaptation/changes according to the commits from Jul 17, 2019 to Jul 19, 2019 --- examples/asr_librispeech/path.sh | 1 - examples/asr_librispeech/run.sh | 2 +- examples/asr_swbd/local/prepare_ctm.py | 70 ++++++++------- examples/asr_swbd/local/score.sh | 35 ++++++++ examples/asr_swbd/local/score_basic.sh | 44 ++++++++++ examples/asr_swbd/local/score_sclite.sh | 7 +- examples/asr_swbd/local/wer_output_filter | 44 +--------- examples/asr_swbd/path.sh | 1 - examples/asr_swbd/run.sh | 87 +++++++++---------- examples/asr_wsj/local/score.sh | 10 +-- examples/asr_wsj/path.sh | 1 - fairseq/criterions/cross_entropy_with_wer.py | 4 +- .../label_smoothed_cross_entropy_with_wer.py | 58 +++++++++---- fairseq/models/speech_lstm.py | 27 ++++++ fairseq/tasks/speech_recognition.py | 1 + speech_recognize.py | 2 +- speech_tools/utils.py | 5 ++ 17 files changed, 246 insertions(+), 153 deletions(-) create mode 100755 examples/asr_swbd/local/score.sh create mode 100755 examples/asr_swbd/local/score_basic.sh mode change 100644 => 100755 examples/asr_swbd/run.sh diff --git a/examples/asr_librispeech/path.sh b/examples/asr_librispeech/path.sh index 3290c7576..d0ebe2157 100644 --- a/examples/asr_librispeech/path.sh +++ b/examples/asr_librispeech/path.sh @@ -14,4 +14,3 @@ export PATH=$MAIN_ROOT:$MAIN_ROOT/speech_tools:$PATH export PATH=$MAIN_ROOT/speech_tools/sentencepiece/build/src:$PATH export PYTHONPATH=$MAIN_ROOT:$MAIN_ROOT/speech_tools:$PYTHONPATH export PYTHONUNBUFFERED=1 - diff --git a/examples/asr_librispeech/run.sh b/examples/asr_librispeech/run.sh index 895027bac..28846bc7c 100755 --- a/examples/asr_librispeech/run.sh +++ b/examples/asr_librispeech/run.sh @@ -211,7 +211,7 @@ if [ ${stage} -le 7 ]; then --log-interval 4000 
--log-format simple --print-training-sample-interval 2000 \ --num-workers 0 --max-tokens 26000 --max-sentences 24 \ --valid-subset $valid_subset --max-sentences-valid 48 \ - --distributed-world-size $ngpus --distributed-port 100 \ + --distributed-world-size $ngpus --distributed-port 100 --ddp-backend no_c10d \ --max-epoch 25 --optimizer adam --lr 0.001 --weight-decay 0.0 --clip-norm 2.0 \ --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --min-lr 1e-5 --start-reduce-lr-epoch 10 \ --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates 3000 \ diff --git a/examples/asr_swbd/local/prepare_ctm.py b/examples/asr_swbd/local/prepare_ctm.py index 705e31ee8..21eea3742 100755 --- a/examples/asr_swbd/local/prepare_ctm.py +++ b/examples/asr_swbd/local/prepare_ctm.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 - # CopyRight (c) 2019-present, Hang Lyu +# 2019-present, Yiming Wang +# All rights reserved. # # This source code is licensed under the license found in the LICENSE file in # the root directory of this source tree. An additional grant of patent rights @@ -16,65 +17,70 @@ import re import sys -def main(): - args = get_args() - convert(args) - -def get_args(): +def get_parser(): parser = argparse.ArgumentParser( description="""Transform the word level results to ctm format""") + # fmt: off parser.add_argument('ori_result', type=str, - help="The input--word level results.") + help='The input: original decoded results.') parser.add_argument('ctm_result', type=str, - help="The output--ctm format result.") - print(' '.join(sys.argv)) - print(sys.argv) + help='The output: transformed results in ctm format.') + # fmt: on - args = parser.parse_args() - return args + return parser -def convert(args): +def main(args): # read in word level results - with open(args.ori_result, 'r', encoding="utf-8") as f: + with open(args.ori_result, 'r', encoding='utf-8') as f: content = f.readlines() # convert each line split_content = [] # store ctm results for i, line in enumerate(content): - elements = line.strip().split(' ') + elements = line.strip().split() # The first field contains the information of the utterance utt_info = elements[0] - infos = re.split("[-_]", utt_info) - utt_id = infos[0] + "_" + infos[1] + infos = re.split('[-_]', utt_info) + utt_id = infos[0] + '_' + infos[1] channel = infos[2] - start_time = round((int(infos[3])/100.0), 2) - end_time = round((int(infos[4])/100.0), 2) + start_time = round((int(infos[3]) / 100.0), 2) + end_time = round((int(infos[4]) / 100.0), 2) # generate ctm format results for each word time_diff = int(infos[4]) - int(infos[3]) - time_step = round((float(time_diff) / (len(elements) - 1) / 100), 2) + time_step = round((float(time_diff) / (len(elements) - 1) / 100), 2) \ + if len(elements) > 1 else 0 for j, word in enumerate(elements): - start_time_tmp = start_time + time_step * (j - 1) + start_time_cur = start_time + time_step * (j - 1) duration = 0.0 if j == 0: continue elif j == len(elements) - 1: - duration = end_time - start_time_tmp - split_content.append(" ".join([utt_id, channel, - str(round(start_time_tmp,2)), - str(round(duration,2)), word])) + duration = end_time - start_time_cur + split_content.append( + ' '.join([utt_id, channel, str(round(start_time_cur, 2)), + str(round(duration, 2)), word]) + ) else: duration = time_step - split_content.append(" ".join([utt_id, channel, - str(round(start_time_tmp,2)), - str(round(duration,2)), word])) - # print + split_content.append( + ' '.join([utt_id, channel, str(round(start_time_cur, 2)), + 
str(round(duration, 2)), word]) + ) + if j == 0: + split_content.append( + ' '.join([utt_id, channel, str(round(start_time, 2)), + str(round(time_diff, 2)), '[noise]']) + ) + with open(args.ctm_result, 'w', encoding='utf-8') as f: for line in split_content: - print(line, file=f) + f.write(line + '\n') -if __name__ == "__main__": - main() +if __name__ == '__main__': + parser = get_parser() + args = parser.parse_args() + main(args) diff --git a/examples/asr_swbd/local/score.sh b/examples/asr_swbd/local/score.sh new file mode 100755 index 000000000..78739eae4 --- /dev/null +++ b/examples/asr_swbd/local/score.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# Copyright (c) 2012, Johns Hopkins University (Author: Daniel Povey) +# 2019-present, Yiming Wang +# All rights reserved. + + +orig_args= +for x in "$@"; do orig_args="$orig_args '$x'"; done + +# begin configuration section. we include all the options that score_sclite.sh or +# score_basic.sh might need, or parse_options.sh will die. +# CAUTION: these default values do not have any effect because of the +# way pass things through to the scripts that this script calls. +cmd=run.pl +#end configuration section. + +[ -f ./path.sh ] && . ./path.sh +. parse_options.sh || exit 1; + +if [ $# -ne 2 ]; then + echo "Usage: local/score.sh [options] " && exit; + echo " Options:" + echo " --cmd (run.pl|queue.pl...) # specify how to run the sub-processes." + exit 1; +fi + +data=$1 + +if [ -f $data/stm ]; then # use sclite scoring. + echo "$data/stm exists: using local/score_sclite.sh" + eval local/score_sclite.sh $orig_args +else + echo "$data/stm does not exist: using local/score_basic.sh" + eval local/score_basic.sh $orig_args +fi diff --git a/examples/asr_swbd/local/score_basic.sh b/examples/asr_swbd/local/score_basic.sh new file mode 100755 index 000000000..2c5e2d902 --- /dev/null +++ b/examples/asr_swbd/local/score_basic.sh @@ -0,0 +1,44 @@ +#!/bin/bash +# Copyright (c) 2012, Johns Hopkins University (Author: Daniel Povey) +# 2019-present, Yiming Wang +# Apache 2.0 + +# begin configuration section. +cmd=run.pl +#end configuration section. + +[ -f ./path.sh ] && . ./path.sh +. parse_options.sh || exit 1; + +if [ $# -ne 2 ]; then + echo "Usage: local/score_basic.sh [--cmd (run.pl|queue.pl...)] " + echo " Options:" + echo " --cmd (run.pl|queue.pl...) # specify how to run the sub-processes." + exit 1; +fi + +data=$1 +dir=$2 + +for f in $data/text $dir/decoded_results.txt; do + [ ! 
-f $f ] && echo "$0: expecting file $f to exist" && exit 1; +done + + +function filter_text { + perl -e 'foreach $w (@ARGV) { $bad{$w} = 1; } + while() { @A = split(" ", $_); $id = shift @A; print "$id "; + foreach $a (@A) { if (!defined $bad{$a}) { print "$a "; }} print "\n"; }' \ + '[noise]' '[laughter]' '[vocalized-noise]' '' '%hesitation' +} + +mkdir -p $dir/scoring/log +filter_text <$data/text >$dir/scoring/test_filt.txt || exit 1; +filter_text <$dir/decoded_results.txt >$dir/scoring/hyp_filt.txt || exit 1; + +$cmd $dir/scoring/log/score.log \ + compute-wer --text --mode=present \ + ark:$dir/scoring/test_filt.txt ark:$dir/scoring/hyp_filt.txt ">&" \ + $dir/scoring/wer || exit 1; + +exit 0 diff --git a/examples/asr_swbd/local/score_sclite.sh b/examples/asr_swbd/local/score_sclite.sh index f965594ef..2c18b6787 100755 --- a/examples/asr_swbd/local/score_sclite.sh +++ b/examples/asr_swbd/local/score_sclite.sh @@ -1,5 +1,8 @@ #!/bin/bash -# Copyright 2019-present Hang Lyu +# Copyright (c) 2012, Johns Hopkins University (Author: Daniel Povey) +# 2019-present, Hang Lyu +# 2019-present, Yiming Wang +# Apache 2.0 # begin configuration section. cmd=run.pl @@ -34,7 +37,7 @@ mkdir -p $dir/scoring/log if [ $stage -le 0 ]; then # prepare the $name.ctm files for test set - python3 local/prepare_ctm.py $dir/decoded_results.txt $dir/scoring/$name.ctm + local/prepare_ctm.py $dir/decoded_results.txt $dir/scoring/$name.ctm fi if [ $stage -le 1 ]; then diff --git a/examples/asr_swbd/local/wer_output_filter b/examples/asr_swbd/local/wer_output_filter index d4ea2c52d..ac9dda18e 100755 --- a/examples/asr_swbd/local/wer_output_filter +++ b/examples/asr_swbd/local/wer_output_filter @@ -2,46 +2,6 @@ s:\[noise\]::g s:\[laughter\]::g s:\[vocalized-noise\]::g -s:^uh ::g -s: uh$::g -s: uh : :g -s:^um ::g -s: um$::g -s: um : :g -s:^eh ::g -s: eh$::g -s: eh : :g -s:^mm ::g -s: mm$::g -s: mm : :g -s:^hm ::g -s: hm$::g -s: hm : :g -s:^ah ::g -s: ah$::g -s: ah : :g -s:^huh ::g -s: huh$::g -s: huh : :g -s:^ha ::g -s: ha$::g -s: ha : :g -s:^er ::g -s: er$::g -s: er : :g -s:^oof ::g -s: oof$::g -s: oof : :g -s:^hee ::g -s: hee$::g -s: hee : :g -s:^ach ::g -s: ach$::g -s: ach : :g -s:^eee ::g -s: eee$::g -s: eee : :g -s:^ew ::g -s: ew$::g -s: ew : :g +s:::g +s:%hesitation::g diff --git a/examples/asr_swbd/path.sh b/examples/asr_swbd/path.sh index 3290c7576..d0ebe2157 100644 --- a/examples/asr_swbd/path.sh +++ b/examples/asr_swbd/path.sh @@ -14,4 +14,3 @@ export PATH=$MAIN_ROOT:$MAIN_ROOT/speech_tools:$PATH export PATH=$MAIN_ROOT/speech_tools/sentencepiece/build/src:$PATH export PYTHONPATH=$MAIN_ROOT:$MAIN_ROOT/speech_tools:$PYTHONPATH export PYTHONUNBUFFERED=1 - diff --git a/examples/asr_swbd/run.sh b/examples/asr_swbd/run.sh old mode 100644 new mode 100755 index 5b9a21c74..a99b0c094 --- a/examples/asr_swbd/run.sh +++ b/examples/asr_swbd/run.sh @@ -1,6 +1,7 @@ #!/bin/bash # Copyright (c) 2019-present, Hang Lyu +# 2019-present, Yiming Wang # All rights reserved. 
# # This source code is licensed under the license found in the LICENSE file in @@ -10,7 +11,7 @@ set -e -o pipefail stage=0 -ngpus=2 # num GPUs for multiple GPUs training within a single node; should match those in $free_gpu +ngpus=1 # num GPUs for multiple GPUs training within a single node; should match those in $free_gpu free_gpu= # comma-separated available GPU ids, eg., "0" or "0,1"; automatically assigned if on CLSP grid # E2E model related @@ -20,8 +21,6 @@ valid_set=train_dev test_sets="train_dev eval2000 rt03" checkpoint=checkpoint_best.pt -validate_on_train_subset=false # for monitoring E2E model training - # LM related lm_affix= lm_checkpoint=checkpoint_best.pt @@ -42,8 +41,6 @@ if [[ $(hostname -f) == *.clsp.jhu.edu ]]; then rt03_dir=/export/corpora/LDC/LDC2007S10 fisher_dirs="/export/corpora3/LDC/LDC2004T19/fe_03_p1_tran/ /export/corpora3/LDC/LDC2005T19/fe_03_p2_tran/" fi -train_subset_size=500 # for validation if validate_on_train_subset is set to true -kaldi_scoring=true # feature configuration do_delta=false @@ -85,7 +82,6 @@ if [ $stage -le 0 ]; then echo "Succeeded in formatting data." fi - train_feat_dir=$dumpdir/$train_set/delta${do_delta}; mkdir -p $train_feat_dir valid_feat_dir=$dumpdir/$valid_set/delta${do_delta}; mkdir -p $valid_feat_dir if [ $stage -le 1 ]; then @@ -110,12 +106,12 @@ if [ $stage -le 1 ]; then # dump features for training if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $train_feat_dir/storage ]; then utils/create_split_dir.pl \ - /export/b{14,15,16,17}/$USER/espnet-data/egs/swbd/asr1/dump/$train_set/delta${do_delta}/storage \ + /export/b1{4,5,6,7}/$USER/fairseq-data/egs/asr_swbd/dump/$train_set/delta${do_delta}/storage \ $train_feat_dir/storage fi if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $valid_feat_dir/storage ]; then utils/create_split_dir.pl \ - /export/b{14,15,16,17}/$USER/espnet-data/egs/swbd/asr1/dump/$valid_set/delta${do_delta}/storage \ + /export/b1{4,5,6,7}/$USER/fairseq-data/egs/asr_swbd/dump/$valid_set/delta${do_delta}/storage \ $valid_feat_dir/storage fi dump.sh --cmd "$train_cmd" --nj 32 --do_delta $do_delta \ @@ -131,7 +127,6 @@ if [ $stage -le 1 ]; then echo "Succeeded in generating features for train_nodup, train_dev, eval2000 and rt03" fi - dict=data/lang/${train_set}_${sentencepiece_type}${sentencepiece_vocabsize}_units.txt sentencepiece_model=data/lang/${train_set}_${sentencepiece_type}${sentencepiece_vocabsize} nlsyms=data/lang/non_lang_syms.txt @@ -154,7 +149,7 @@ if [ $stage -le 2 ]; then [ ! -d $x/data/trans ] \ && "Cannot find transcripts in Fisher directory $x" && exit 1; cat $x/data/trans/*/*.txt | \ - grep -v ^# | grep -v ^$ | cut -d' ' -f4- >> $lmdatadir/fisher_text0 + grep -v '^#' | grep -v '^$' | cut -d' ' -f4- >> $lmdatadir/fisher_text0 done cat $lmdatadir/fisher_text0 | local/fisher_map_words.pl | \ sed 's/^[ \t]*//'> $lmdatadir/fisher_text @@ -168,18 +163,12 @@ if [ $stage -le 2 ]; then --input_sentence_size=10000000 \ --user_defined_symbols=$(cut -f 2- $train_text | tr " " "\n" | sort | uniq | grep "\[" | tr "\n" "," | sed 's/,$//') - echo "Making a dictionary and tokenizing text for train/valid/test sets..." + echo "Tokenizing text for train/valid/test sets..." 
for dataset in $train_set $test_sets; do # validation is included in tests text=data/$dataset/text token_text=data/$dataset/token_text spm_encode --model=${sentencepiece_model}.model --output_format=piece \ <(cut -f 2- -d' ' $text) | paste -d" " <(cut -f 1 -d' ' $text) - > $token_text - # prepare dict with train_set - if [ "$dataset" == "$train_set" ]; then - cut -f 2- -d" " $token_text | tr " " "\n" | grep -v -e '^\s*$' | sort | \ - uniq -c | awk '{print $2,$1}' > $dict - wc -l $dict - fi done echo "Preparing text for subword LM..." @@ -193,8 +182,12 @@ if [ $stage -le 2 ]; then cat $lmdatadir/fisher_text |\ spm_encode --model=${sentencepiece_model}.model --output_format=piece |\ cat $lmdatadir/$train_set.tokens - > $lmdatadir/train.tokens -fi + echo "Making a dictionary with swbd+fisher text" + cat $lmdatadir/train.tokens | tr " " "\n" | grep -v -e '^\s*$' | sort | \ + uniq -c | awk '{print $2,$1}' > $dict + wc -l $dict +fi lmdict=$dict if [ $stage -le 3 ]; then @@ -211,14 +204,12 @@ if [ $stage -le 3 ]; then --destdir $lmdatadir fi - [ -z "$free_gpu" ] && [[ $(hostname -f) == *.clsp.jhu.edu ]] && free_gpu=$(free-gpu -n $ngpus) || \ echo "Unable to get $ngpus GPUs" [ -z "$free_gpu" ] && echo "$0: please specify --free-gpu" && exit 1; [ $(echo $free_gpu | sed 's/,/ /g' | awk '{print NF}') -ne "$ngpus" ] && \ echo "number of GPU ids in --free-gpu=$free_gpu does not match --ngpus=$ngpus" && exit 1; - if [ $stage -le 4 ]; then echo "Stage 4: subword LM Training" valid_subset=valid @@ -235,10 +226,9 @@ if [ $stage -le 4 ]; then --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \ --save-dir $lmdir --restore-file checkpoint_last.pt --save-interval-updates 500 \ --keep-interval-updates 3 --keep-last-epochs 5 --validate-interval 1 \ - --arch lstm_lm_librispeech --criterion cross_entropy --sample-break-mode eos 2>&1 | tee $log_file + --arch lstm_lm_swbd --criterion cross_entropy --sample-break-mode eos 2>&1 | tee $log_file fi - if [ $stage -le 5 ]; then echo "Stage 5: subword LM Evaluation" gen_set_array=(test) @@ -247,14 +237,13 @@ if [ $stage -le 5 ]; then test_set_array=($test_sets) for i in $(seq 0 $num); do log_file=$lmdir/logs/evaluation_${test_set_array[$i]}.log - python3 ../../eval_lm.py $lmdatadir --cpu \ + python3 ../../eval_lm.py $lmdatadir \ --task language_modeling_for_asr --dict $lmdict --gen-subset ${gen_set_array[$i]} \ --max-tokens 40960 --max-sentences 1536 --sample-break-mode eos \ --path $lmdir/$lm_checkpoint 2>&1 | tee $log_file done fi - train_feat=$train_feat_dir/feats.scp train_token_text=data/$train_set/token_text valid_feat=$valid_feat_dir/feats.scp @@ -268,26 +257,26 @@ if [ $stage -le 6 ]; then log_file=$dir/logs/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --seed 1 \ - --log-interval 1500 --log-format simple --print-training-sample-interval 2000 --ddp-backend "no_c10d" \ - --num-workers 0 --max-tokens 26000 --max-sentences 24 \ - --valid-subset $valid_subset --max-sentences-valid 48 \ - --distributed-world-size $ngpus --distributed-rank 0 --distributed-port 100 \ - --max-epoch 25 --optimizer adam --lr 0.001 --weight-decay 0.0 --clip-norm 2.0 \ + --log-interval 1500 --log-format simple --print-training-sample-interval 2000 \ + --num-workers 0 --max-tokens 26000 --max-sentences 48 \ + --valid-subset $valid_subset --max-sentences-valid 64 \ + --distributed-world-size $ngpus --distributed-rank 0 --distributed-port 100 --ddp-backend no_c10d \ + --max-epoch 40 --optimizer adam --lr 0.001 --weight-decay 
0.0 --clip-norm 2.0 \ --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --min-lr 1e-5 --start-reduce-lr-epoch 10 \ --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates 1500 \ --keep-interval-updates 3 --keep-last-epochs 5 --validate-interval 1 --best-checkpoint-metric wer \ - --arch speech_conv_lstm_librispeech --criterion label_smoothed_cross_entropy_with_wer \ + --arch speech_conv_lstm_swbd --criterion label_smoothed_cross_entropy_with_wer \ --label-smoothing 0.1 --smoothing-type uniform \ - --scheduled-sampling-probs 1.0 --start-scheduled-sampling-epoch 1 \ + --scheduled-sampling-probs 0.9,0.8,0.7,0.6 --start-scheduled-sampling-epoch 6 \ --train-feat-files $train_feat --train-text-files $train_token_text \ --valid-feat-files $valid_feat --valid-text-files $valid_token_text \ --dict $dict --remove-bpe sentencepiece --non-lang-syms $nlsyms \ --max-source-positions 9999 --max-target-positions 999 $opts 2>&1 | tee $log_file fi - if [ $stage -le 7 ]; then echo "Stage 7: Decoding" + [ ! -d $KALDI_ROOT ] && echo "Expected $KALDI_ROOT to exist" && exit 1; opts="" path=$dir/$checkpoint decode_affix= @@ -298,21 +287,31 @@ if [ $stage -le 7 ]; then fi [ -f local/wer_output_filter ] && opts="$opts --wer-output-filter local/wer_output_filter" for dataset in $test_sets; do - feat=${dumpdir}/$dataset/delta${do_delta}/feats.scp - text=data/$dataset/token_text + decode_dir=$dir/decode_${dataset}${decode_affix:+_${decode_affix}} + # only score train_dev with built-in scorer + text_opt= && [ "$dataset" == "train_dev" ] && text_opt="--test-text-files data/$dataset/token_text" CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py \ - --max-tokens 16000 --max-sentences 24 --num-shards 1 --shard-id 0 \ - --test-feat-files $feat --test-text-files $text \ + --max-tokens 24000 --max-sentences 48 --num-shards 1 --shard-id 0 \ + --test-feat-files ${dumpdir}/$dataset/delta${do_delta}/feats.scp $text_opt \ --dict $dict --remove-bpe sentencepiece --non-lang-syms $nlsyms \ --max-source-positions 9999 --max-target-positions 999 \ - --path $path --beam 30 --max-len-a 0.08 --max-len-b 0 --lenpen 1.0 \ - --results-path $dir/decode_$dataset${decode_affix:+_${decode_affix}} $opts \ + --path $path --beam 35 --max-len-a 0.1 --max-len-b 0 --lenpen 1.0 \ + --results-path $decode_dir $opts \ 2>&1 | tee $dir/logs/decode_$dataset${decode_affix:+_${decode_affix}}.log + + echo "Scoring with kaldi..." + local/score.sh data/$dataset $decode_dir + if [ "$dataset" == "train_dev" ]; then + echo -n "tran_dev: " && cat $decode_dir/scoring/wer | grep WER + elif [ "$dataset" == "eval2000" ] || [ "$dataset" == "rt03" ]; then + echo -n "$dataset: " && grep Sum $decode_dir/scoring/$dataset.ctm.filt.sys | \ + awk '{print "WER="$11"%, Sub="$8"%, Ins="$10"%, Del="$9"%"}' | tee $decode_dir/wer + echo -n "swbd subset: " && grep Sum $decode_dir/scoring/$dataset.ctm.swbd.filt.sys | \ + awk '{print "WER="$11"%, Sub="$8"%, Ins="$10"%, Del="$9"%"}' | tee $decode_dir/wer_swbd + subset=callhm && [ "$dataset" == "rt03" ] && subset=fsh + echo -n "$subset subset: " && grep Sum $decode_dir/scoring/$dataset.ctm.$subset.filt.sys | \ + awk '{print "WER="$11"%, Sub="$8"%, Ins="$10"%, Del="$9"%"}' | tee $decode_dir/wer_$subset + echo "WERs saved in $decode_dir/wer*" + fi done - if $kaldi_scoring; then - echo "verify WER by scoring with kaldi..." 
- # The word level results are stored in decode_dir/decoded_results.txt - local/score_sclite.sh data/eval2000 $dir/decode_eval2000${decode_affix:+_${decode_affix}} - local/score_sclite.sh data/rt03 $dir/decode_rt03${decode_affix:+_${decode_affix}} - fi fi diff --git a/examples/asr_wsj/local/score.sh b/examples/asr_wsj/local/score.sh index 80653d43d..1746f1afd 100755 --- a/examples/asr_wsj/local/score.sh +++ b/examples/asr_wsj/local/score.sh @@ -1,11 +1,8 @@ #!/bin/bash -# Copyright (c) 2019-present, Yiming Wang -# All rights reserved. -# -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. +# Copyright (c) 2012-2014, Johns Hopkins University (Author: Daniel Povey, Yenda Trmal) +# 2019-present, Yiming Wang +# Apache 2.0 # begin configuration section. @@ -43,3 +40,4 @@ $cmd $dir/scoring_kaldi/log/score.log \ compute-wer --text --mode=present \ ark:$dir/scoring_kaldi/test_filt.txt ark,p:- ">&" $dir/scoring_kaldi/wer || exit 1; +exit 0 diff --git a/examples/asr_wsj/path.sh b/examples/asr_wsj/path.sh index 80ed44a7f..19308a610 100644 --- a/examples/asr_wsj/path.sh +++ b/examples/asr_wsj/path.sh @@ -13,4 +13,3 @@ export PATH=~/anaconda3/bin:$PATH export PATH=$MAIN_ROOT:$MAIN_ROOT/speech_tools:$PATH export PYTHONPATH=$MAIN_ROOT:$MAIN_ROOT/speech_tools:$PYTHONPATH export PYTHONUNBUFFERED=1 - diff --git a/fairseq/criterions/cross_entropy_with_wer.py b/fairseq/criterions/cross_entropy_with_wer.py index 26c78e5d9..661abc7e6 100644 --- a/fairseq/criterions/cross_entropy_with_wer.py +++ b/fairseq/criterions/cross_entropy_with_wer.py @@ -223,10 +223,8 @@ def _decode(self, tokens, model, encoder_out, incremental_states): decoder_out[0] = decoder_out[0][:, -1:, :] attn = decoder_out[1] if type(attn) is dict: - attn = attn['attn'] + attn = attn.get('attn', None) if attn is not None: - if type(attn) is dict: - attn = attn['attn'] attn = attn[:, -1, :] probs = model.get_normalized_probs(decoder_out, log_probs=True) probs = probs[:, -1, :] diff --git a/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py b/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py index 8e00f0bce..0c29c7862 100644 --- a/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py +++ b/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py @@ -17,6 +17,36 @@ from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion +def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True, + smoothing_type='uniform', prob_mask=None, unigram_tensor=None): + if target.dim() == lprobs.dim() - 1: + target = target.unsqueeze(-1) + nll_loss = -lprobs.gather(dim=-1, index=target) + if smoothing_type == 'temporal': + assert torch.is_tensor(prob_mask) + smooth_loss = -lprobs.mul(prob_mask).sum(-1, keepdim=True) + elif smoothing_type == 'unigram': + assert torch.is_tensor(unigram_tensor) + smooth_loss = -lprobs.matmul(unigram_tensor.to(lprobs)) + elif smoothing_type == 'uniform': + smooth_loss = -lprobs.sum(dim=-1, keepdim=True) + else: + raise ValueError('Unsupported smoothing type: {}'.format(smoothing_type)) + if ignore_index is not None: + non_pad_mask = target.ne(ignore_index) + nll_loss = nll_loss[non_pad_mask] + smooth_loss = smooth_loss[non_pad_mask] + else: + nll_loss = nll_loss.squeeze(-1) + smooth_loss = smooth_loss.squeeze(-1) + if reduce: + nll_loss = nll_loss.sum() + smooth_loss = smooth_loss.sum() + eps_i = epsilon / 
lprobs.size(-1) if smoothing_type == 'uniform' else epsilon + loss = (1. - epsilon) * nll_loss + eps_i * smooth_loss + return loss, nll_loss + + @register_criterion('label_smoothed_cross_entropy_with_wer') class LabelSmoothedCrossEntropyWithWERCriterion(LabelSmoothedCrossEntropyCriterion): @@ -30,6 +60,7 @@ def __init__(self, args, task): self.valid_tgt_dataset = None self.num_updates = -1 self.epoch = 0 + self.unigram_tensor = None if args.smoothing_type == 'unigram': self.unigram_tensor = torch.cuda.FloatTensor(dict.count).unsqueeze(-1) \ if torch.cuda.is_available() and not args.cpu \ @@ -188,6 +219,7 @@ def forward(self, model, sample, reduce=True): print('| sample REF: ' + ref_one) print('| sample PRD: ' + pred_one) # word error stats code ends + prob_mask = None if self.args.smoothing_type == 'temporal': # see https://arxiv.org/pdf/1612.02695.pdf # prob_mask.dtype=int for deterministic behavior of Tensor.scatter_add_() @@ -206,26 +238,16 @@ def forward(self, model, sample, reduce=True): prob_mask[:, :, self.padding_idx] = 0 # clear cumulative count on prob_mask = prob_mask.float() # convert to float sum_prob = prob_mask.sum(-1, keepdim=True) - sum_prob[sum_prob.squeeze(-1).eq(0.)] = 1. # deal with "divided by 0" problem + sum_prob[sum_prob.squeeze(-1).eq(0.)] = 1. # to deal with the "division by 0" problem prob_mask = prob_mask.div_(sum_prob).view(-1, prob_mask.size(-1)) lprobs = lprobs.view(-1, lprobs.size(-1)) target = target.view(-1, 1) - non_pad_mask = target.ne(self.padding_idx) - nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask] - if self.args.smoothing_type == 'temporal': - smooth_loss = -lprobs.mul(prob_mask).sum(-1, keepdim=True)[non_pad_mask] - elif self.args.smoothing_type == 'unigram': - smooth_loss = -lprobs.matmul(self.unigram_tensor.to(lprobs))[non_pad_mask] - elif self.args.smoothing_type == 'uniform': - smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask] - else: - raise ValueError('Unsupported smoothing type: {}'.format(self.args.smoothing_type)) - if reduce: - nll_loss = nll_loss.sum() - smooth_loss = smooth_loss.sum() - eps_i = self.eps / lprobs.size(-1) if self.args.smoothing_type == 'uniform' else self.eps - loss = (1. 
- self.eps) * nll_loss + eps_i * smooth_loss + loss, nll_loss = label_smoothed_nll_loss( + lprobs, target, self.eps, ignore_index=self.padding_idx, reduce=reduce, + smoothing_type=self.args.smoothing_type, prob_mask=prob_mask, + unigram_tensor=self.unigram_tensor, + ) sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] logging_output = { 'loss': utils.item(loss.data) if reduce else loss.data, @@ -263,10 +285,8 @@ def _decode(self, tokens, model, encoder_out, incremental_states): decoder_out[0] = decoder_out[0][:, -1:, :] attn = decoder_out[1] if type(attn) is dict: - attn = attn['attn'] + attn = attn.get('attn', None) if attn is not None: - if type(attn) is dict: - attn = attn['attn'] attn = attn[:, -1, :] probs = model.get_normalized_probs(decoder_out, log_probs=True) probs = probs[:, -1, :] diff --git a/fairseq/models/speech_lstm.py b/fairseq/models/speech_lstm.py index 4cf1b1d58..650505cee 100644 --- a/fairseq/models/speech_lstm.py +++ b/fairseq/models/speech_lstm.py @@ -418,6 +418,7 @@ def output_lengths(self, in_lengths): def forward(self, src_tokens, src_lengths): if self.left_pad: + # nn.utils.rnn.pack_padded_sequence requires right-padding; # convert left-padding to right-padding src_tokens = speech_utils.convert_padding_direction( src_tokens, @@ -775,6 +776,17 @@ def lstm_lm_librispeech(args): base_lm_architecture(args) +@register_model_architecture('lstm_lm', 'lstm_lm_swbd') +def lstm_lm_swbd(args): + args.dropout = getattr(args, 'dropout', 0.3) + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1800) + args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 1800) + args.decoder_layers = getattr(args, 'decoder_layers', 3) + args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 1800) + args.share_embed = getattr(args, 'share_embed', True) + base_lm_architecture(args) + + @register_model_architecture('lstm_lm', 'lstm_wordlm_wsj') def lstm_wordlm_wsj(args): args.dropout = getattr(args, 'dropout', 0.35) @@ -837,3 +849,18 @@ def speech_conv_lstm_librispeech(args): args.attention_type = getattr(args, 'attention_type', 'bahdanau') args.attention_dim = getattr(args, 'attention_dim', 512) base_architecture(args) + + +@register_model_architecture('speech_lstm', 'speech_conv_lstm_swbd') +def speech_conv_lstm_swbd(args): + args.dropout = getattr(args, 'dropout', 0.5) + args.encoder_rnn_hidden_size = getattr(args, 'encoder_rnn_hidden_size', 600) + args.encoder_rnn_layers = getattr(args, 'encoder_rnn_layers', 3) + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 600) + args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 600) + args.decoder_layers = getattr(args, 'decoder_layers', 3) + args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 1800) + args.decoder_rnn_residual = getattr(args, 'decoder_rnn_residual', True) + args.attention_type = getattr(args, 'attention_type', 'bahdanau') + args.attention_dim = getattr(args, 'attention_dim', 600) + base_architecture(args) diff --git a/fairseq/tasks/speech_recognition.py b/fairseq/tasks/speech_recognition.py index c5c733ca9..cfffbfd39 100644 --- a/fairseq/tasks/speech_recognition.py +++ b/fairseq/tasks/speech_recognition.py @@ -237,6 +237,7 @@ def build_generator(self, args): unk_penalty=getattr(args, 'unkpen', 0), sampling=getattr(args, 'sampling', False), sampling_topk=getattr(args, 'sampling_topk', -1), + sampling_topp=getattr(args, 'sampling_topp', -1.0), temperature=getattr(args, 'temperature', 1.), diverse_beam_groups=getattr(args, 
'diverse_beam_groups', -1), diverse_beam_strength=getattr(args, 'diverse_beam_strength', 0.5), diff --git a/speech_recognize.py b/speech_recognize.py index c294bfdee..70b9d4492 100755 --- a/speech_recognize.py +++ b/speech_recognize.py @@ -139,7 +139,7 @@ def main(args): print('T-{}\t{}'.format(utt_id, target_sent)) # Process top predictions - for j, hypo in enumerate(hypos[i][:min(len(hypos), args.nbest)]): + for j, hypo in enumerate(hypos[i][:args.nbest]): hypo_str = dict.string(hypo['tokens'].int().cpu()) # not removing bpe at this point if not args.quiet or i == 0: hypo_sent = dict.tokens_to_sentence(hypo_str, bpe_symbol=args.remove_bpe) diff --git a/speech_tools/utils.py b/speech_tools/utils.py index 01978d0f9..1029cc9d3 100644 --- a/speech_tools/utils.py +++ b/speech_tools/utils.py @@ -189,6 +189,11 @@ def aligned_print(ref, hyp, steps): assert isinstance(ref, list) and isinstance(hyp, list) assert isinstance(steps, list) + if len(steps) == 0: # in case both ref and hyp are empty + assert len(ref) == 0 and len(hyp) == 0 + out_str = 'REF: \nHYP: \nSTP: \nWER: {:.2f}%\n\n'.format(0.) + return out_str + out_str = 'REF: ' for i in range(len(steps)): delim = ' ' if i < len(steps) - 1 else '\n' From 236c927ab0bb6a0f47bd95ee4d65322822b388a5 Mon Sep 17 00:00:00 2001 From: freewym Date: Sun, 21 Jul 2019 01:09:59 -0400 Subject: [PATCH 028/119] revise speech_{transformer/fconv} code --- fairseq/models/speech_fconv.py | 15 ++------------- fairseq/models/speech_transformer.py | 25 +++---------------------- 2 files changed, 5 insertions(+), 35 deletions(-) diff --git a/fairseq/models/speech_fconv.py b/fairseq/models/speech_fconv.py index b8e387e92..e5b5f0ad5 100644 --- a/fairseq/models/speech_fconv.py +++ b/fairseq/models/speech_fconv.py @@ -150,8 +150,8 @@ class SpeechFConvEncoder(FConvEncoder): Args: conv_layers_before (~fairseq.speech_lstm.ConvBNReLU): convolutions befoe fconv layers - input_size (int, optional): dim of input to the transformer before being - projected to embed_dim + input_size (int, optional): dimension of the input to the transformer + before being projected to embed_dim embed_dim (int, optional): embedding dimension max_positions (int, optional): maximum supported input sequence length convolutions (list, optional): the convolutional layer structure. 
Each
@@ -295,17 +295,6 @@ def forward(self, src_tokens, src_lengths):
             'encoder_padding_mask': encoder_padding_mask, # B x T
         }
 
-    def reorder_encoder_out(self, encoder_out, new_order):
-        if encoder_out['encoder_out'] is not None:
-            encoder_out['encoder_out'] = (
-                encoder_out['encoder_out'][0].index_select(0, new_order),
-                encoder_out['encoder_out'][1].index_select(0, new_order),
-            )
-        if encoder_out['encoder_padding_mask'] is not None:
-            encoder_out['encoder_padding_mask'] = \
-                encoder_out['encoder_padding_mask'].index_select(0, new_order)
-        return encoder_out
-
     def max_positions(self):
         """Maximum input length supported by the encoder."""
         return int(1e5)
diff --git a/fairseq/models/speech_transformer.py b/fairseq/models/speech_transformer.py
index e73143d0a..bf8398310 100644
--- a/fairseq/models/speech_transformer.py
+++ b/fairseq/models/speech_transformer.py
@@ -162,10 +162,10 @@ class SpeechTransformerEncoder(TransformerEncoder):
 
     Args:
         args (argparse.Namespace): parsed command-line arguments
-        conv_layers_before (~fairseq.speech_lstm.ConvBNReLU): convolutions befoe
+        conv_layers_before (~fairseq.speech_lstm.ConvBNReLU): convolutions before
            transformer layers
-        input_size (int, optional): dim of input to the transformer before being
-            projected to args.encoder_embed_dim
+        input_size (int, optional): dimension of the input to the transformer
+            before being projected to args.encoder_embed_dim
     """
 
     def __init__(self, args, conv_layers_before=None, input_size=83):
@@ -239,25 +239,6 @@ def forward(self, src_tokens, src_lengths):
             'encoder_padding_mask': encoder_padding_mask, # B x T
         }
 
-    def reorder_encoder_out(self, encoder_out, new_order):
-        """
-        Reorder encoder output according to *new_order*.
-
-        Args:
-            encoder_out: output from the ``forward()`` method
-            new_order (LongTensor): desired order
-
-        Returns:
-            *encoder_out* rearranged according to *new_order*
-        """
-        if encoder_out['encoder_out'] is not None:
-            encoder_out['encoder_out'] = \
-                encoder_out['encoder_out'].index_select(1, new_order)
-        if encoder_out['encoder_padding_mask'] is not None:
-            encoder_out['encoder_padding_mask'] = \
-                encoder_out['encoder_padding_mask'].index_select(0, new_order)
-        return encoder_out
-
     def max_positions(self):
         """Maximum input length supported by the encoder."""
         return self.max_source_positions

From 3b38acdba54873a46b888d6579d0913e71045386 Mon Sep 17 00:00:00 2001
From: freewym
Date: Thu, 25 Jul 2019 23:19:38 -0400
Subject: [PATCH 029/119] improve swbd recipe; continue training while lr is no less than --min-lr; code adaptation/changes according to the commits from Jul 24, 2019 to Aug 1, 2019

---
 examples/asr_swbd/run.sh | 22 +++++++--------
 fairseq/data/speech_dataset.py | 11 ++++----
 fairseq/models/speech_lstm.py | 31 +++++++++++++--------
 fairseq/tasks/speech_recognition.py | 1 -
 speech_train.py | 43 +++++++++++++++++++++--------
 5 files changed, 67 insertions(+), 41 deletions(-)

diff --git a/examples/asr_swbd/run.sh b/examples/asr_swbd/run.sh
index a99b0c094..2d9038870 100755
--- a/examples/asr_swbd/run.sh
+++ b/examples/asr_swbd/run.sh
@@ -18,7 +18,7 @@ free_gpu= # comma-separated available GPU ids, eg., "0" or "0,1"; automatically
 affix=
 train_set=train_nodup
 valid_set=train_dev
-test_sets="train_dev eval2000 rt03"
+test_set="train_dev eval2000 rt03"
 checkpoint=checkpoint_best.pt
 
 # LM related
 lm_affix=
 lm_checkpoint=checkpoint_best.pt
@@ -118,7 +118,7 @@ if [ $stage -le 1 ]; then
     data/$train_set/feats.scp data/$train_set/cmvn.ark exp/dump_feats/train $train_feat_dir
   dump.sh --cmd "$train_cmd" --nj 10 --do_delta $do_delta \
data/$valid_set/feats.scp data/$train_set/cmvn.ark exp/dump_feats/dev $valid_feat_dir - for rtask in $test_sets; do + for rtask in $test_set; do test_feat_dir=$dumpdir/$rtask/delta${do_delta}; mkdir -p $test_feat_dir dump.sh --cmd "$train_cmd" --nj 10 --do_delta $do_delta \ data/$rtask/feats.scp data/$train_set/cmvn.ark exp/dump_feats/recog/$rtask \ @@ -164,7 +164,7 @@ if [ $stage -le 2 ]; then --user_defined_symbols=$(cut -f 2- $train_text | tr " " "\n" | sort | uniq | grep "\[" | tr "\n" "," | sed 's/,$//') echo "Tokenizing text for train/valid/test sets..." - for dataset in $train_set $test_sets; do # validation is included in tests + for dataset in $train_set $test_set; do # validation is included in tests text=data/$dataset/text token_text=data/$dataset/token_text spm_encode --model=${sentencepiece_model}.model --output_format=piece \ @@ -173,7 +173,7 @@ if [ $stage -le 2 ]; then echo "Preparing text for subword LM..." mkdir -p $lmdatadir - for dataset in $train_set $test_sets; do + for dataset in $train_set $test_set; do token_text=data/$dataset/token_text cut -f 2- -d" " $token_text > $lmdatadir/$dataset.tokens done @@ -193,7 +193,7 @@ lmdict=$dict if [ $stage -le 3 ]; then echo "Stage 3: Text Binarization for subword LM Training" mkdir -p $lmdatadir/logs - for dataset in $test_sets; do test_paths="$test_paths $lmdatadir/$dataset.tokens"; done + for dataset in $test_set; do test_paths="$test_paths $lmdatadir/$dataset.tokens"; done test_paths=$(echo $test_paths | awk '{$1=$1;print}' | tr ' ' ',') ${decode_cmd} $lmdatadir/logs/preprocess.log \ python3 ../../preprocess.py --task language_modeling_for_asr \ @@ -232,9 +232,9 @@ fi if [ $stage -le 5 ]; then echo "Stage 5: subword LM Evaluation" gen_set_array=(test) - num=$(echo $test_sets | awk '{print NF-1}') + num=$(echo $test_set | awk '{print NF-1}') for i in $(seq $num); do gen_set_array[$i]="test$i"; done #gen_set_array=(test test1 test2) - test_set_array=($test_sets) + test_set_array=($test_set) for i in $(seq 0 $num); do log_file=$lmdir/logs/evaluation_${test_set_array[$i]}.log python3 ../../eval_lm.py $lmdatadir \ @@ -258,10 +258,10 @@ if [ $stage -le 6 ]; then [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --seed 1 \ --log-interval 1500 --log-format simple --print-training-sample-interval 2000 \ - --num-workers 0 --max-tokens 26000 --max-sentences 48 \ + --num-workers 0 --max-tokens 26000 --max-sentences 48 --curriculum 2 \ --valid-subset $valid_subset --max-sentences-valid 64 \ --distributed-world-size $ngpus --distributed-rank 0 --distributed-port 100 --ddp-backend no_c10d \ - --max-epoch 40 --optimizer adam --lr 0.001 --weight-decay 0.0 --clip-norm 2.0 \ + --max-epoch 35 --optimizer adam --lr 0.001 --weight-decay 0.0 --clip-norm 2.0 \ --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --min-lr 1e-5 --start-reduce-lr-epoch 10 \ --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates 1500 \ --keep-interval-updates 3 --keep-last-epochs 5 --validate-interval 1 --best-checkpoint-metric wer \ @@ -282,11 +282,11 @@ if [ $stage -le 7 ]; then decode_affix= if $lm_shallow_fusion; then path="$path:$lmdir/$lm_checkpoint" - opts="$opts --lm-weight 0.3 --coverage-weight 0.0" + opts="$opts --lm-weight 0.25 --coverage-weight 0.0" decode_affix=shallow_fusion fi [ -f local/wer_output_filter ] && opts="$opts --wer-output-filter local/wer_output_filter" - for dataset in $test_sets; do + for dataset in $test_set; do 
decode_dir=$dir/decode_${dataset}${decode_affix:+_${decode_affix}} # only score train_dev with built-in scorer text_opt= && [ "$dataset" == "train_dev" ] && text_opt="--test-text-files data/$dataset/token_text" diff --git a/fairseq/data/speech_dataset.py b/fairseq/data/speech_dataset.py index 699b96619..1d497f0f0 100644 --- a/fairseq/data/speech_dataset.py +++ b/fairseq/data/speech_dataset.py @@ -99,8 +99,7 @@ class SpeechDataset(FairseqDataset): shuffle (bool, optional): shuffle dataset elements before batching (default: True) input_feeding (bool, optional): create a shifted version of the targets - to be passed into the model for input feeding/teacher forcing - (default: True) + to be passed into the model for teacher forcing (default: True). """ def __init__( @@ -177,10 +176,10 @@ def collater(self, samples): - `src_lengths` (IntTensor): 1D Tensor of the unpadded lengths of each source sequence of shape `(bsz)` - `prev_output_tokens` (LongTensor): a padded 2D Tensor of - tokens in the target sentence, shifted right by one position - for input feeding/teacher forcing, of shape `(bsz, - tgt_len)`. This key will not be present if *input_feeding* - is ``False``. Padding will appear on the left if + tokens in the target sentence, shifted right by one + position for teacher forcing, of shape `(bsz, tgt_len)`. + This key will not be present if *input_feeding* is + ``False``. Padding will appear on the left if *left_pad_target* is ``True``. - `target` (LongTensor): a padded 2D Tensor of tokens in the diff --git a/fairseq/models/speech_lstm.py b/fairseq/models/speech_lstm.py index 650505cee..34e964d6b 100644 --- a/fairseq/models/speech_lstm.py +++ b/fairseq/models/speech_lstm.py @@ -18,9 +18,15 @@ register_model, register_model_architecture, ) +from fairseq.models.lstm import ( + AttentionLayer, + Embedding, + LSTM, + LSTMCell, + Linear, +) from fairseq.modules import AdaptiveSoftmax, speech_attention - -from .lstm import AttentionLayer, Embedding, LSTM, LSTMCell, Linear +from fairseq.tasks.speech_recognition import SpeechRecognitionTask import speech_tools.utils as speech_utils @@ -287,9 +293,12 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): utils.print_embed_overlap(embed_dict, dictionary) return utils.load_embedding(embed_dict, dictionary, embed_tokens) - dictionary = task.word_dictionary \ - if args.is_wordlm and hasattr(task, 'word_dictionary') \ - else task.target_dictionary + if args.is_wordlm and hasattr(task, 'word_dictionary'): + dictionary = task.word_dictionary + elif isinstance(task, SpeechRecognitionTask): + dictionary = task.target_dictionary + else: + dictionary = task.source_dictionary # separate decoder input embeddings pretrained_decoder_embed = None @@ -854,13 +863,13 @@ def speech_conv_lstm_librispeech(args): @register_model_architecture('speech_lstm', 'speech_conv_lstm_swbd') def speech_conv_lstm_swbd(args): args.dropout = getattr(args, 'dropout', 0.5) - args.encoder_rnn_hidden_size = getattr(args, 'encoder_rnn_hidden_size', 600) - args.encoder_rnn_layers = getattr(args, 'encoder_rnn_layers', 3) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 600) - args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 600) + args.encoder_rnn_hidden_size = getattr(args, 'encoder_rnn_hidden_size', 640) + args.encoder_rnn_layers = getattr(args, 'encoder_rnn_layers', 4) + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 640) + args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 640) args.decoder_layers = 
getattr(args, 'decoder_layers', 3) - args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 1800) + args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 1920) args.decoder_rnn_residual = getattr(args, 'decoder_rnn_residual', True) args.attention_type = getattr(args, 'attention_type', 'bahdanau') - args.attention_dim = getattr(args, 'attention_dim', 600) + args.attention_dim = getattr(args, 'attention_dim', 640) base_architecture(args) diff --git a/fairseq/tasks/speech_recognition.py b/fairseq/tasks/speech_recognition.py index cfffbfd39..a713faf74 100644 --- a/fairseq/tasks/speech_recognition.py +++ b/fairseq/tasks/speech_recognition.py @@ -231,7 +231,6 @@ def build_generator(self, args): max_len_a=getattr(args, 'max_len_a', 0), max_len_b=getattr(args, 'max_len_b', 200), min_len=getattr(args, 'min_len', 1), - stop_early=(not getattr(args, 'no_early_stop', False)), normalize_scores=(not getattr(args, 'unnormalized', False)), len_penalty=getattr(args, 'lenpen', 1), unk_penalty=getattr(args, 'unkpen', 0), diff --git a/speech_train.py b/speech_train.py index 11edcc00b..d6ad18a73 100755 --- a/speech_train.py +++ b/speech_train.py @@ -12,7 +12,6 @@ import collections import math -import os import random import torch @@ -37,6 +36,9 @@ def main(args, init_distributed=False): if init_distributed: args.distributed_rank = distributed_utils.distributed_init(args) + if distributed_utils.is_master(args): + checkpoint_utils.verify_checkpoint_directory(args.save_dir) + # Print args print(args) @@ -81,7 +83,7 @@ def main(args, init_distributed=False): valid_losses = [None] valid_subsets = args.valid_subset.split(',') while ( - (lr > args.min_lr or trainer.get_num_updates() <= getattr(args, 'warmup_updates', 0)) + (lr >= args.min_lr or trainer.get_num_updates() <= getattr(args, 'warmup_updates', 0)) and ( epoch_itr.epoch < max_epoch or ( epoch_itr.epoch == max_epoch @@ -145,7 +147,7 @@ def train(args, trainer, task, epoch_itr): for k, v in log_output.items(): if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']: continue # these are already logged above - if 'loss' in k: + if 'loss' in k or k == 'accuracy': extra_meters[k].update(v, log_output['sample_size']) else: extra_meters[k].update(v) @@ -264,21 +266,20 @@ def validate(args, trainer, task, epoch_itr, subsets): extra_meters[k].update(v) # log validation stats - stats = get_valid_stats(trainer) + stats = get_valid_stats(trainer, args, extra_meters) for k, meter in extra_meters.items(): - stats[k] = meter if k == 'wer' or k == 'cer' else meter.avg - if hasattr(checkpoint_utils.save_checkpoint, 'best'): - stats['best_' + args.best_checkpoint_metric] = min( - checkpoint_utils.save_checkpoint.best, - stats[args.best_checkpoint_metric].avg, - ) + stats[k] = meter.avg progress.print(stats, tag=subset, step=trainer.get_num_updates()) - valid_losses.append(stats[args.best_checkpoint_metric].avg) + valid_losses.append( + stats[args.best_checkpoint_metric].avg + if args.best_checkpoint_metric == 'loss' + else stats[args.best_checkpoint_metric] + ) return valid_losses -def get_valid_stats(trainer): +def get_valid_stats(trainer, args, extra_meters=None): stats = collections.OrderedDict() stats['loss'] = trainer.get_meter('valid_loss') if trainer.get_meter('valid_nll_loss').count > 0: @@ -288,6 +289,24 @@ def get_valid_stats(trainer): nll_loss = stats['loss'] stats['ppl'] = utils.get_perplexity(nll_loss.avg) stats['num_updates'] = trainer.get_num_updates() + if hasattr(checkpoint_utils.save_checkpoint, 'best'): + key = 
'best_{0}'.format(args.best_checkpoint_metric) + best_function = max if args.maximize_best_checkpoint_metric else min + + current_metric = None + if args.best_checkpoint_metric == 'loss': + current_metric = stats['loss'].avg + elif args.best_checkpoint_metric in extra_meters: + current_metric = extra_meters[args.best_checkpoint_metric].avg + elif args.best_checkpoint_metric in stats: + current_metric = stats[args.best_checkpoint_metric] + else: + raise ValueError("best_checkpoint_metric not found in logs") + + stats[key] = best_function( + checkpoint_utils.save_checkpoint.best, + current_metric, + ) return stats From 833ad033134ead56fd58e35465bb719e2b3e5815 Mon Sep 17 00:00:00 2001 From: freewym Date: Fri, 2 Aug 2019 00:50:04 -0400 Subject: [PATCH 030/119] Relicense Espresso under MIT license --- examples/asr_librispeech/run.sh | 9 +++------ examples/asr_swbd/local/prepare_ctm.py | 9 +++------ examples/asr_swbd/run.sh | 10 +++------- examples/asr_wsj/run.sh | 9 +++------ fairseq/criterions/cross_entropy_with_wer.py | 8 +++----- .../label_smoothed_cross_entropy_with_wer.py | 8 +++----- fairseq/data/scp_dataset.py | 8 +++----- fairseq/data/speech_dataset.py | 8 +++----- fairseq/data/token_dictionary.py | 8 +++----- fairseq/models/external_language_model.py | 8 +++----- fairseq/models/speech_fconv.py | 8 +++----- fairseq/models/speech_lstm.py | 8 +++----- fairseq/models/speech_transformer.py | 8 +++----- fairseq/modules/speech_attention.py | 8 +++----- fairseq/optim/lr_scheduler/reduce_lr_on_plateau_v2.py | 8 +++----- fairseq/tasks/language_modeling_for_asr.py | 8 +++----- fairseq/tasks/speech_recognition.py | 8 +++----- fairseq/wer.py | 8 +++----- speech_recognize.py | 10 ++++------ speech_tools/compute_wer.py | 8 +++----- speech_tools/text2token.py | 8 +++----- speech_tools/text2vocabulary.py | 8 +++----- speech_tools/utils.py | 8 +++----- speech_train.py | 10 ++++------ tests/test_speech_dataset.py | 8 +++----- tests/test_speech_utils.py | 8 +++----- 26 files changed, 80 insertions(+), 137 deletions(-) diff --git a/examples/asr_librispeech/run.sh b/examples/asr_librispeech/run.sh index 28846bc7c..7c22523d4 100755 --- a/examples/asr_librispeech/run.sh +++ b/examples/asr_librispeech/run.sh @@ -1,11 +1,8 @@ #!/bin/bash - -# Copyright (c) 2019-present, Yiming Wang -# All rights reserved. +# Copyright (c) Yiming Wang # -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. set -e -o pipefail diff --git a/examples/asr_swbd/local/prepare_ctm.py b/examples/asr_swbd/local/prepare_ctm.py index 21eea3742..b1119220d 100755 --- a/examples/asr_swbd/local/prepare_ctm.py +++ b/examples/asr_swbd/local/prepare_ctm.py @@ -1,11 +1,8 @@ #!/usr/bin/env python3 -# CopyRight (c) 2019-present, Hang Lyu -# 2019-present, Yiming Wang -# All rights reserved. +# Copyright (c) Hang Lyu, Yiming Wang # -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
# This script is use to transform the word level results to ctm format # The start_time and end_time of each word is fake diff --git a/examples/asr_swbd/run.sh b/examples/asr_swbd/run.sh index 2d9038870..156c22450 100755 --- a/examples/asr_swbd/run.sh +++ b/examples/asr_swbd/run.sh @@ -1,12 +1,8 @@ #!/bin/bash - -# Copyright (c) 2019-present, Hang Lyu -# 2019-present, Yiming Wang -# All rights reserved. +# Copyright (c) Hang Lyu, Yiming Wang # -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. set -e -o pipefail diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index eedc5194d..6fd47aa07 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -1,11 +1,8 @@ #!/bin/bash - -# Copyright (c) 2019-present, Yiming Wang -# All rights reserved. +# Copyright (c) Yiming Wang # -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. set -e -o pipefail diff --git a/fairseq/criterions/cross_entropy_with_wer.py b/fairseq/criterions/cross_entropy_with_wer.py index 661abc7e6..fad3d104e 100644 --- a/fairseq/criterions/cross_entropy_with_wer.py +++ b/fairseq/criterions/cross_entropy_with_wer.py @@ -1,9 +1,7 @@ -# Copyright (c) 2018-present, Yiming Wang -# All rights reserved. +# Copyright (c) Yiming Wang # -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. import numpy as np import torch diff --git a/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py b/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py index 0c29c7862..1ad541ead 100644 --- a/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py +++ b/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py @@ -1,9 +1,7 @@ -# Copyright (c) 2019-present, Yiming Wang -# All rights reserved. +# Copyright (c) Yiming Wang # -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. import numpy as np import torch diff --git a/fairseq/data/scp_dataset.py b/fairseq/data/scp_dataset.py index a6fc3e444..2f9836e5c 100644 --- a/fairseq/data/scp_dataset.py +++ b/fairseq/data/scp_dataset.py @@ -1,9 +1,7 @@ -# Copyright (c) 2018-present, Yiming Wang -# All rights reserved. +# Copyright (c) Yiming Wang # -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. 
+# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. import os diff --git a/fairseq/data/speech_dataset.py b/fairseq/data/speech_dataset.py index 1d497f0f0..9c7a02e87 100644 --- a/fairseq/data/speech_dataset.py +++ b/fairseq/data/speech_dataset.py @@ -1,9 +1,7 @@ -# Copyright (c) 2018-present, Yiming Wang -# All rights reserved. +# Copyright (c) Yiming Wang # -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. import numpy as np import torch diff --git a/fairseq/data/token_dictionary.py b/fairseq/data/token_dictionary.py index a91d89d09..dff3177d5 100644 --- a/fairseq/data/token_dictionary.py +++ b/fairseq/data/token_dictionary.py @@ -1,9 +1,7 @@ -# Copyright (c) 2018-present, Yiming Wang -# All rights reserved. +# Copyright (c) Yiming Wang # -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. import torch diff --git a/fairseq/models/external_language_model.py b/fairseq/models/external_language_model.py index cd807b21d..6e4288063 100644 --- a/fairseq/models/external_language_model.py +++ b/fairseq/models/external_language_model.py @@ -1,9 +1,7 @@ -# Copyright (c) 2019-present, Yiming Wang -# All rights reserved. +# Copyright (c) Yiming Wang # -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. import math import torch diff --git a/fairseq/models/speech_fconv.py b/fairseq/models/speech_fconv.py index e5b5f0ad5..64a28ae5e 100644 --- a/fairseq/models/speech_fconv.py +++ b/fairseq/models/speech_fconv.py @@ -1,9 +1,7 @@ -# Copyright (c) 2019-present, Yiming Wang -# All rights reserved. +# Copyright (c) Yiming Wang # -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. import math import torch diff --git a/fairseq/models/speech_lstm.py b/fairseq/models/speech_lstm.py index 34e964d6b..5cd9f4f62 100644 --- a/fairseq/models/speech_lstm.py +++ b/fairseq/models/speech_lstm.py @@ -1,9 +1,7 @@ -# Copyright (c) 2018-present, Yiming Wang -# All rights reserved. +# Copyright (c) Yiming Wang # -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
import torch import torch.nn as nn diff --git a/fairseq/models/speech_transformer.py b/fairseq/models/speech_transformer.py index bf8398310..91996b752 100644 --- a/fairseq/models/speech_transformer.py +++ b/fairseq/models/speech_transformer.py @@ -1,9 +1,7 @@ -# Copyright (c) 2019-present, Yiming Wang -# All rights reserved. +# Copyright (c) Yiming Wang # -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. import math diff --git a/fairseq/modules/speech_attention.py b/fairseq/modules/speech_attention.py index d8d4e5524..114fa1aca 100644 --- a/fairseq/modules/speech_attention.py +++ b/fairseq/modules/speech_attention.py @@ -1,9 +1,7 @@ -# Copyright (c) 2018-present, Yiming Wang -# All rights reserved. +# Copyright (c) Yiming Wang # -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. import math import torch diff --git a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau_v2.py b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau_v2.py index c9cdabf7a..6348a313f 100644 --- a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau_v2.py +++ b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau_v2.py @@ -1,9 +1,7 @@ -# Copyright (c) 2019-present, Yiming Wang -# All rights reserved. +# Copyright (c) Yiming Wang # -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. import torch.optim.lr_scheduler diff --git a/fairseq/tasks/language_modeling_for_asr.py b/fairseq/tasks/language_modeling_for_asr.py index 5270257ff..1f94bc0d4 100644 --- a/fairseq/tasks/language_modeling_for_asr.py +++ b/fairseq/tasks/language_modeling_for_asr.py @@ -1,9 +1,7 @@ -# Copyright (c) 2017-present, Facebook, Inc. -# All rights reserved. +# Copyright (c) Yiming Wang # -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. import torch diff --git a/fairseq/tasks/speech_recognition.py b/fairseq/tasks/speech_recognition.py index a713faf74..6b9cafbf6 100644 --- a/fairseq/tasks/speech_recognition.py +++ b/fairseq/tasks/speech_recognition.py @@ -1,9 +1,7 @@ -# Copyright (c) 2018-present, Yiming Wang -# All rights reserved. +# Copyright (c) Yiming Wang # -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
import torch diff --git a/fairseq/wer.py b/fairseq/wer.py index c98bc30d3..52aa2e5b6 100644 --- a/fairseq/wer.py +++ b/fairseq/wer.py @@ -1,9 +1,7 @@ -# Copyright (c) 2017-present, Facebook, Inc. -# All rights reserved. +# Copyright (c) Yiming Wang # -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. import re diff --git a/speech_recognize.py b/speech_recognize.py index 70b9d4492..422bcaa4e 100755 --- a/speech_recognize.py +++ b/speech_recognize.py @@ -1,11 +1,9 @@ #!/usr/bin/env python3 -# Copyright (c) 2017-present, Facebook, Inc. -# 2018-present, Yiming Wang -# All rights reserved. +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Yiming Wang # -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. """ Recognize pre-processed speech with a trained model. """ diff --git a/speech_tools/compute_wer.py b/speech_tools/compute_wer.py index e0ae5257c..05003fb9c 100755 --- a/speech_tools/compute_wer.py +++ b/speech_tools/compute_wer.py @@ -1,10 +1,8 @@ #!/usr/bin/env python3 - -# Copyright (c) 2019-present, Yiming Wang +# Copyright (c) Yiming Wang # -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. import argparse import sys, re diff --git a/speech_tools/text2token.py b/speech_tools/text2token.py index 6b8a201af..38f58e874 100755 --- a/speech_tools/text2token.py +++ b/speech_tools/text2token.py @@ -1,10 +1,8 @@ #!/usr/bin/env python3 - -# Copyright (c) 2019-present, Yiming Wang +# Copyright (c) Yiming Wang # -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. import argparse import sys diff --git a/speech_tools/text2vocabulary.py b/speech_tools/text2vocabulary.py index 3b43e1ce6..9168c470c 100755 --- a/speech_tools/text2vocabulary.py +++ b/speech_tools/text2vocabulary.py @@ -1,10 +1,8 @@ #!/usr/bin/env python3 - -# Copyright (c) 2019-present, Yiming Wang +# Copyright (c) Yiming Wang # -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
import argparse import sys diff --git a/speech_tools/utils.py b/speech_tools/utils.py index 1029cc9d3..1db622fec 100644 --- a/speech_tools/utils.py +++ b/speech_tools/utils.py @@ -1,9 +1,7 @@ -# Copyright (c) 2018-present, Yiming Wang -# All rights reserved. +# Copyright (c) Yiming Wang # -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. import os, re import numpy as np diff --git a/speech_train.py b/speech_train.py index d6ad18a73..156f8da16 100755 --- a/speech_train.py +++ b/speech_train.py @@ -1,11 +1,9 @@ #!/usr/bin/env python3 -# Copyright (c) 2017-present, Facebook, Inc. -# 2018-present, Yiming Wang -# All rights reserved. +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) Yiming Wang # -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. """ Train a new model on one or across multiple GPUs. """ diff --git a/tests/test_speech_dataset.py b/tests/test_speech_dataset.py index 7e1f0bc0e..062c83891 100644 --- a/tests/test_speech_dataset.py +++ b/tests/test_speech_dataset.py @@ -1,9 +1,7 @@ -# Copyright (c) 2018-present, Yiming Wang -# All rights reserved. +# Copyright (c) Yiming Wang # -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. import unittest import string diff --git a/tests/test_speech_utils.py b/tests/test_speech_utils.py index 4e66b5f96..d3aa0945b 100644 --- a/tests/test_speech_utils.py +++ b/tests/test_speech_utils.py @@ -1,9 +1,7 @@ -# Copyright (c) 2018-present, Yiming Wang -# All rights reserved. +# Copyright (c) Yiming Wang # -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
import unittest import string From 2777ac7eeb5b5f6ba9df8daf2f25b1f0772948f1 Mon Sep 17 00:00:00 2001 From: freewym Date: Fri, 2 Aug 2019 19:51:45 -0400 Subject: [PATCH 031/119] add --eos-factor for beam search to alleviate the problem of too short transcripts with LM fusion --- examples/asr_librispeech/run.sh | 2 +- examples/asr_wsj/run.sh | 2 +- fairseq/tasks/speech_recognition.py | 1 + speech_recognize.py | 3 +++ 4 files changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/asr_librispeech/run.sh b/examples/asr_librispeech/run.sh index 7c22523d4..978c24340 100755 --- a/examples/asr_librispeech/run.sh +++ b/examples/asr_librispeech/run.sh @@ -229,7 +229,7 @@ if [ ${stage} -le 8 ]; then decode_affix= if $lm_shallow_fusion; then path="$path:$lmdir/$lm_checkpoint" - opts="$opts --lm-weight 0.4 --coverage-weight 0.015" + opts="$opts --lm-weight 0.4 --coverage-weight 0.0 --eos-factor 1.5" decode_affix=shallow_fusion fi for dataset in $test_set; do diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index 6fd47aa07..ffa5d4fec 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -294,7 +294,7 @@ if [ ${stage} -le 9 ]; then decode_affix=shallow_fusion else path="$path:$wordlmdir/$lm_checkpoint" - opts="$opts --word-dict $wordlmdict --lm-weight 0.8 --oov-penalty 1e-8 --coverage-weight 0.01" + opts="$opts --word-dict $wordlmdict --lm-weight 0.8 --oov-penalty 1e-8 --coverage-weight 0.005 --eos-factor 1.5" decode_affix=shallow_fusion_wordlm fi fi diff --git a/fairseq/tasks/speech_recognition.py b/fairseq/tasks/speech_recognition.py index 6b9cafbf6..10700daef 100644 --- a/fairseq/tasks/speech_recognition.py +++ b/fairseq/tasks/speech_recognition.py @@ -241,6 +241,7 @@ def build_generator(self, args): match_source_len=getattr(args, 'match_source_len', False), no_repeat_ngram_size=getattr(args, 'no_repeat_ngram_size', 0), coverage_weight=getattr(args, 'coverage_weight', 0.0), + eos_factor=getattr(args, 'eos_factor', None), ) def build_dataset_for_inference(self, src_tokens, src_lengths): diff --git a/speech_recognize.py b/speech_recognize.py index 422bcaa4e..79877cc05 100755 --- a/speech_recognize.py +++ b/speech_recognize.py @@ -223,6 +223,9 @@ def cli_main(): help='coverage weight in log-prob space, mostly to ' 'reduce deletion errors while using the pretrained ' 'external LM for decoding') + parser.add_argument('--eos-factor', default=None, type=float, metavar='F', + help='only consider emitting EOS if its score is no less ' + 'than the specified factor of the best candidate score') parser.add_argument('--lm-weight', default=0.0, type=float, metavar='W', help='LM weight in log-prob space, assuming the pretrained ' 'external LM is specified as the second one in --path') From 5cf1c41bfb39ba85b97399ebd9db52acc7d48d8d Mon Sep 17 00:00:00 2001 From: freewym Date: Wed, 7 Aug 2019 13:24:43 -0400 Subject: [PATCH 032/119] tokenize each sentence such that it ends with ; modify look-ahead LM accordingly; modify the WSJ recipe accordingly --- examples/asr_wsj/run.sh | 8 +-- fairseq/models/external_language_model.py | 81 +++++++++-------------- fairseq/models/speech_lstm.py | 2 +- speech_tools/text2token.py | 13 +++- 4 files changed, 48 insertions(+), 56 deletions(-) diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index ffa5d4fec..897da4d77 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -225,7 +225,7 @@ if [ ${stage} -le 6 ] && $use_wordlm; then CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../train.py $wordlmdatadir --seed 1 \ --task 
language_modeling_for_asr --dict $wordlmdict \ --log-interval 2000 --log-format simple \ - --num-workers 0 --max-tokens 6300 --max-sentences 256 \ + --num-workers 0 --max-tokens 6400 --max-sentences 256 \ --valid-subset $valid_subset --max-sentences-valid 512 \ --distributed-world-size $ngpus --distributed-port 100 \ --max-epoch 25 --optimizer adam --lr 0.001 --weight-decay 0.0 \ @@ -266,10 +266,10 @@ if [ ${stage} -le 8 ]; then [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --seed 1 \ --log-interval 400 --log-format simple --print-training-sample-interval 1000 \ - --num-workers 0 --max-tokens 24000 --max-sentences 32 \ + --num-workers 0 --max-tokens 24000 --max-sentences 32 --curriculum 2 \ --valid-subset $valid_subset --max-sentences-valid 64 \ --distributed-world-size $ngpus --distributed-port 100 --ddp-backend no_c10d \ - --max-epoch 30 --optimizer adam --lr 0.001 --weight-decay 0.0 \ + --max-epoch 35 --optimizer adam --lr 0.001 --weight-decay 0.0 \ --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --min-lr 1e-5 --start-reduce-lr-epoch 11 \ --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates 400 \ --keep-interval-updates 5 --keep-last-epochs 5 --validate-interval 1 --best-checkpoint-metric wer \ @@ -294,7 +294,7 @@ if [ ${stage} -le 9 ]; then decode_affix=shallow_fusion else path="$path:$wordlmdir/$lm_checkpoint" - opts="$opts --word-dict $wordlmdict --lm-weight 0.8 --oov-penalty 1e-8 --coverage-weight 0.005 --eos-factor 1.5" + opts="$opts --word-dict $wordlmdict --lm-weight 0.9 --oov-penalty 1e-7 --coverage-weight 0.0 --eos-factor 1.5" decode_affix=shallow_fusion_wordlm fi fi diff --git a/fairseq/models/external_language_model.py b/fairseq/models/external_language_model.py index 6e4288063..3d82f24a6 100644 --- a/fairseq/models/external_language_model.py +++ b/fairseq/models/external_language_model.py @@ -41,7 +41,8 @@ def __init__(self, wordlm, subword_dict, oov_penalty=1e-4, open_vocab=True): class _LookAheadWordLanguageModelDecoder(FairseqIncrementalDecoder): """Look-ahead word language model decoder for end-to-end ASR. It is intended to be used for beam search decoding. See https://arxiv.org/abs/1808.02608 - for details. + for details. We modify the original algorithm a little bit to adapt it to + the case where each tokenized sentence ends with before . 
""" def __init__(self, wordlm, subword_dict, oov_penalty=1e-4, open_vocab=True): super().__init__(wordlm.decoder.dictionary) @@ -98,9 +99,10 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): assert (prev_output_tokens == self.subword_eos_idx).all(), \ 'expecting the input to the first time step to be ' w = prev_output_tokens.new_full([bsz, 1], self.word_eos_idx) - lm_out = self.lm_decoder(w, incremental_state=incremental_state) - cumsum_probs = torch.cumsum(self.lm_decoder.get_normalized_probs( - lm_out, log_probs=False, sample=None), dim=-1) # B x 1 x V + lm_probs = self.lm_decoder.get_normalized_probs( + self.lm_decoder(w, incremental_state=incremental_state), + log_probs=False, sample=None) # B x 1 x V + cumsum_probs = torch.cumsum(lm_probs, dim=-1) # B x 1 x V nodes = [self.lexroot] * bsz else: cumsum_probs = utils.get_incremental_state( @@ -112,16 +114,15 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): self.word_unk_idx for node in nodes ]).unsqueeze(-1) # B x 1 old_cached_state = _clone_cached_state(cached_state) - lm_out = self.lm_decoder(w, incremental_state=incremental_state) - self.lm_decoder.masked_copy_incremental_state(incremental_state, - old_cached_state, batch_space_mask) # recompute cumsum_probs from inter-word transition probabilities # only for those whose prev_output_token is - cumsum_probs[batch_space_mask] = torch.cumsum( - self.lm_decoder.get_normalized_probs(lm_out, log_probs=False, - sample=None), - dim=-1, - )[batch_space_mask] + lm_probs = self.lm_decoder.get_normalized_probs( + self.lm_decoder(w, incremental_state=incremental_state), + log_probs=False, sample=None) # B x 1 x V + self.lm_decoder.masked_copy_incremental_state(incremental_state, + old_cached_state, batch_space_mask) # restore those not masked + cumsum_probs[batch_space_mask] = \ + torch.cumsum(lm_probs, dim=-1)[batch_space_mask] tokens_list = prev_output_tokens.squeeze(-1).tolist() for i in range(bsz): if tokens_list[i] == self.subword_space_idx: @@ -144,12 +145,13 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): cumsum_probs[:, :, self.word_unk_idx] - \ cumsum_probs[:, :, self.word_unk_idx - 1] ).unsqueeze(-1).repeat(1, 1, self.subword_vocab_size) - # set the probability of emitting or to 0 if - # prev_output_tokens is or + # set the probability of emitting to 0 if prev_output_tokens + # is or , and that of emitting to 0 if + # prev_output_tokens is not batch_space_eos_mask = batch_space_mask | \ prev_output_tokens.squeeze(-1).eq(self.subword_eos_idx) out_probs[batch_space_eos_mask, :, self.subword_space_idx] = self.zero - out_probs[batch_space_eos_mask, :, self.subword_eos_idx] = self.zero + out_probs[~batch_space_mask, :, self.subword_eos_idx] = self.zero # set transition probability to 1 for those whose node is out of the # tree, i.e. node is None (case 4 in Eqn. 15) batch_node_none_mask = [] @@ -206,7 +208,7 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): out_probs.scatter_(-1, subword_idx, cumsum_probs_children) out_probs[:, :, self.subword_pad_idx] = self.zero - # apply word-level probabilies for and (case 1 in Eqn. 15) + # apply word-level probabilies for (case 1 in Eqn. 
15) word_idx, batch_node_word_end_mask = [], [] for node in nodes: if node is not None and node.word_idx >= 0: @@ -225,26 +227,13 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): ).squeeze(-1).div_(sum_probs[batch_node_word_end_mask]), ) # b x 1 out_probs[batch_node_word_end_mask, :, self.subword_space_idx] = word_probs - out_probs[batch_node_word_end_mask, :, self.subword_eos_idx] = word_probs # take log of probs and clip it from below to avoid log(0) out_logprobs = torch.max(out_probs, out_probs.new([self.zero])).log_() - # add log-probs of emitting word to that of emitting subword - cached_state = _clone_cached_state(utils.get_incremental_state( - self.lm_decoder, incremental_state, 'cached_state')) # for restore later - w = prev_output_tokens.new([ - node.word_idx if node is not None and node.word_idx >= 0 else \ - self.word_unk_idx for node in nodes - ]).unsqueeze(-1) # B x 1 - word_eos_logprobs = self.lm_decoder.get_normalized_probs( - self.lm_decoder(w, incremental_state=incremental_state), - log_probs=True, - sample=None, - )[:, :, self.word_eos_idx] - utils.set_incremental_state(self.lm_decoder, incremental_state, - 'cached_state', cached_state) # restore decoder's state - out_logprobs[:, :, self.subword_eos_idx] += word_eos_logprobs + # assign log-probs of emitting word to that of emitting subword + out_logprobs[batch_space_mask, :, self.subword_eos_idx] = \ + lm_probs.log_()[batch_space_mask, :, self.word_eos_idx] # note that here we return log-probs rather than logits, and the second # element is None, which is usually a tensor of attention weights in @@ -291,7 +280,9 @@ def __init__(self, wordlm, subwordlm, subwordlm_weight=0.8, oov_penalty=1.0, class _MultiLevelLanguageModel(FairseqIncrementalDecoder): """Multi-level (subword/word) language model decoder for end-to-end ASR. It is intended to be used for beam search decoding. - See https://ieeexplore.ieee.org/document/8268948 for details. + See https://ieeexplore.ieee.org/document/8268948 for details. We modify the + original algorithm a little bit to adapt it to the case where each tokenized + sentence ends with before . 
""" def __init__(self, wordlm, subwordlm, subwordlm_weight=0.8, oov_penalty=1.0, open_vocab=True): @@ -380,7 +371,7 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): sample=None, )[batch_space_mask] self.wordlm_decoder.masked_copy_incremental_state(incremental_state, - old_wordlm_cached_state, batch_space_mask) + old_wordlm_cached_state, batch_space_mask) # restore those not masked tokens_list = prev_output_tokens.squeeze(-1).tolist() token_idx, batch_is_child_mask = [], [] @@ -427,7 +418,7 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): subword_cumlogprobs) utils.set_incremental_state(self, incremental_state, 'nodes', nodes) - # apply word-level probabilies for emitting or + # apply word-level probabilies for emitting w = prev_output_tokens.new([ node.word_idx if node is not None and node.word_idx >= 0 else \ self.word_unk_idx for node in nodes @@ -437,29 +428,21 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): word_logprobs += torch.where(batch_word_end_mask, -subword_cumlogprobs, word_logprobs.new([self.log_oov_penalty])) out_logprobs[:, :, self.subword_space_idx] = word_logprobs - out_logprobs[:, :, self.subword_eos_idx] = word_logprobs - # set the probability of emitting or to 0 if - # prev_output_tokens is or + # set the probability of emitting to 0 if prev_output_tokens is + # or , and that of emitting to 0 if prev_output_tokens + # is not batch_space_eos_mask = batch_space_mask | \ prev_output_tokens.squeeze(-1).eq(self.subword_eos_idx) out_logprobs[batch_space_eos_mask, :, self.subword_space_idx] = self.logzero - out_logprobs[batch_space_eos_mask, :, self.subword_eos_idx] = self.logzero + out_logprobs[~batch_space_mask, :, self.subword_eos_idx] = self.logzero # add log-probs of emitting word to that of emitting subword - cached_state = _clone_cached_state(utils.get_incremental_state( - self.wordlm_decoder, incremental_state, 'cached_state')) # for restore later - word_eos_logprobs = self.wordlm_decoder.get_normalized_probs( - self.wordlm_decoder(w, incremental_state=incremental_state), - log_probs=True, - sample=None, - )[:, :, self.word_eos_idx] - out_logprobs[:, :, self.subword_eos_idx] += word_eos_logprobs + out_logprobs[batch_space_mask, :, self.subword_eos_idx] += \ + wordlm_logprobs[batch_space_mask, :, self.word_eos_idx] utils.set_incremental_state(self, incremental_state, 'out_logprobs', out_logprobs) - utils.set_incremental_state(self.wordlm_decoder, incremental_state, - 'cached_state', cached_state) # restore decoder's state # note that here we return log-probs rather than logits, and the second # element is None, which is usually a tensor of attention weights in diff --git a/fairseq/models/speech_lstm.py b/fairseq/models/speech_lstm.py index 5cd9f4f62..0f4440f16 100644 --- a/fairseq/models/speech_lstm.py +++ b/fairseq/models/speech_lstm.py @@ -808,7 +808,7 @@ def lstm_wordlm_wsj(args): @register_model_architecture('speech_lstm', 'speech_lstm') def base_architecture(args): - args.dropout = getattr(args, 'dropout', 0.3) + args.dropout = getattr(args, 'dropout', 0.4) args.encoder_conv_channels = getattr(args, 'encoder_conv_channels', '[64, 64, 128, 128]') args.encoder_conv_kernel_sizes = getattr(args, 'encoder_conv_kernel_sizes', diff --git a/speech_tools/text2token.py b/speech_tools/text2token.py index 38f58e874..9db179a08 100755 --- a/speech_tools/text2token.py +++ b/speech_tools/text2token.py @@ -18,6 +18,9 @@ def get_parser(): help='skip first n columns') 
parser.add_argument('--space', default='', type=str, help='space symbol') + parser.add_argument('--endswithspace', default=True, type=bool, + help='Whether to append to the end of each ' + 'tokenized sentence.') parser.add_argument('--non-lang-syms', default=None, type=str, help='path to a file listing non-linguistic symbols, ' 'e.g., etc. One entry per line.') @@ -39,9 +42,15 @@ def main(args): tokenized = tokenize(' '.join(entry[args.skip_ncols:]), space=args.space, non_lang_syms=nls) if args.skip_ncols > 0: - print(' '.join(entry[:args.skip_ncols]) + ' ' + tokenized) + if args.endswithspace: + print(' '.join(entry[:args.skip_ncols]) + ' ' + tokenized + ' ' + args.space) + else: + print(' '.join(entry[:args.skip_ncols]) + ' ' + tokenized) else: - print(tokenized) + if args.endswithspace: + print(tokenized + ' ' + args.space) + else: + print(tokenized) if __name__ == '__main__': From b2d3ef204fb03dda3c97315995bd9549be250eed Mon Sep 17 00:00:00 2001 From: freewym Date: Tue, 20 Aug 2019 15:39:04 -0400 Subject: [PATCH 033/119] switch to pip install sentencepiece; modify Librispeech/SWBD recipes accordingly; code adaptation/changes according to the commits on Aug 21 and 30, 2019 --- examples/asr_librispeech/run.sh | 17 ++++----- examples/asr_swbd/run.sh | 42 ++++++++++------------ fairseq/models/speech_lstm.py | 2 +- fairseq/tasks/language_modeling_for_asr.py | 28 ++++++++------- speech_tools/Makefile | 9 ++--- speech_train.py | 3 +- 6 files changed, 47 insertions(+), 54 deletions(-) diff --git a/examples/asr_librispeech/run.sh b/examples/asr_librispeech/run.sh index 978c24340..599cee685 100755 --- a/examples/asr_librispeech/run.sh +++ b/examples/asr_librispeech/run.sh @@ -107,7 +107,7 @@ if [ ${stage} -le 3 ]; then mkdir -p data/lang cut -f 2- -d" " data/${train_set}/text > data/lang/input echo "$0: training sentencepiece model..." - spm_train --bos_id=-1 --pad_id=0 --eos_id=1 --unk_id=2 --input=data/lang/input \ + python3 ../../scripts/spm_train.py --bos_id=-1 --pad_id=0 --eos_id=1 --unk_id=2 --input=data/lang/input \ --vocab_size=$((sentencepiece_vocabsize+3)) --character_coverage=1.0 \ --model_type=$sentencepiece_type --model_prefix=$sentencepiece_model \ --input_sentence_size=10000000 @@ -115,8 +115,9 @@ if [ ${stage} -le 3 ]; then for dataset in $train_set $valid_set $test_set; do text=data/$dataset/text token_text=data/$dataset/token_text - spm_encode --model=${sentencepiece_model}.model --output_format=piece \ - <(cut -f 2- -d" " $text) | paste -d" " <(cut -f 1 -d" " $text) - > $token_text + cut -f 2- -d" " $text | \ + python3 ../../scripts/spm_encode.py --model=${sentencepiece_model}.model --output_format=piece | \ + paste -d" " <(cut -f 1 -d" " $text) - > $token_text if [ "$dataset" == "$train_set" ]; then cut -f 2- -d" " $token_text | tr ' ' '\n' | sort | uniq -c | \ awk '{print $2,$1}' | sort > $dict @@ -135,7 +136,7 @@ if [ ${stage} -le 3 ]; then fi echo "$0: preparing extra corpus for subword LM training..." 
zcat $lmdatadir/librispeech-lm-norm.txt.gz | \ - spm_encode --model=${sentencepiece_model}.model --output_format=piece | \ + python3 ../../scripts/spm_encode.py --model=${sentencepiece_model}.model --output_format=piece | \ cat $lmdatadir/$train_set.tokens - > $lmdatadir/train.tokens fi @@ -169,7 +170,7 @@ if [ ${stage} -le 5 ]; then CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../train.py $lmdatadir --seed 1 \ --task language_modeling_for_asr --dict $lmdict \ --log-interval 8000 --log-format simple \ - --num-workers 0 --max-tokens 30720 --max-sentences 1024 \ + --num-workers 0 --max-tokens 30720 --max-sentences 1024 --curriculum 1 \ --valid-subset $valid_subset --max-sentences-valid 1536 \ --distributed-world-size $ngpus --distributed-port 100 \ --max-epoch 30 --optimizer adam --lr 0.001 --clip-norm 1.0 \ @@ -206,10 +207,10 @@ if [ ${stage} -le 7 ]; then [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --seed 1 \ --log-interval 4000 --log-format simple --print-training-sample-interval 2000 \ - --num-workers 0 --max-tokens 26000 --max-sentences 24 \ + --num-workers 0 --max-tokens 26000 --max-sentences 24 --curriculum 1 \ --valid-subset $valid_subset --max-sentences-valid 48 \ --distributed-world-size $ngpus --distributed-port 100 --ddp-backend no_c10d \ - --max-epoch 25 --optimizer adam --lr 0.001 --weight-decay 0.0 --clip-norm 2.0 \ + --max-epoch 30 --optimizer adam --lr 0.001 --weight-decay 0.0 --clip-norm 2.0 \ --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --min-lr 1e-5 --start-reduce-lr-epoch 10 \ --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates 3000 \ --keep-interval-updates 3 --keep-last-epochs 5 --validate-interval 1 --best-checkpoint-metric wer \ @@ -229,7 +230,7 @@ if [ ${stage} -le 8 ]; then decode_affix= if $lm_shallow_fusion; then path="$path:$lmdir/$lm_checkpoint" - opts="$opts --lm-weight 0.4 --coverage-weight 0.0 --eos-factor 1.5" + opts="$opts --lm-weight 0.45 --coverage-weight 0.0 --eos-factor 1.5" decode_affix=shallow_fusion fi for dataset in $test_set; do diff --git a/examples/asr_swbd/run.sh b/examples/asr_swbd/run.sh index 156c22450..971cc0efc 100755 --- a/examples/asr_swbd/run.sh +++ b/examples/asr_swbd/run.sh @@ -132,12 +132,12 @@ if [ $stage -le 2 ]; then mkdir -p data/lang mkdir -p $lmdatadir - echo "Making a non-linguistic symbol list..." + echo "$0: making a non-linguistic symbol list..." train_text=data/$train_set/text cut -f 2- $train_text | tr " " "\n" | sort | uniq | grep "\[" > $nlsyms cat $nlsyms - echo "Preparing extra corpus for subword LM training..." + echo "$0: preparing extra corpus for subword LM training..." if [ -f $lmdatadir/fisher_text0 ]; then rm -rf $lmdatadir/fisher_text0 fi @@ -145,42 +145,36 @@ if [ $stage -le 2 ]; then [ ! -d $x/data/trans ] \ && "Cannot find transcripts in Fisher directory $x" && exit 1; cat $x/data/trans/*/*.txt | \ - grep -v '^#' | grep -v '^$' | cut -d' ' -f4- >> $lmdatadir/fisher_text0 + grep -v "^#" | grep -v "^$" | cut -d" " -f4- >> $lmdatadir/fisher_text0 done cat $lmdatadir/fisher_text0 | local/fisher_map_words.pl | \ sed 's/^[ \t]*//'> $lmdatadir/fisher_text - echo "Training sentencepiece model..." - cut -f 2- -d" " data/$train_set/text | \ - cat - $lmdatadir/fisher_text > data/lang/input - spm_train --bos_id=-1 --pad_id=0 --eos_id=1 --unk_id=2 --input=data/lang/input \ + echo "$0: training sentencepiece model..." 
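+  # Descriptive note on the step below: the sentencepiece model is trained on the
+  # concatenation of the SWBD training transcripts and the mapped Fisher text
+  # written to data/lang/input; the non-linguistic symbols collected in $nlsyms are
+  # passed as --user_defined_symbols so they are kept as single pieces.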
+ cut -f 2- -d" " data/$train_set/text | cat - $lmdatadir/fisher_text > data/lang/input + python3 ../../scripts/spm_train.py --bos_id=-1 --pad_id=0 --eos_id=1 --unk_id=2 --input=data/lang/input \ --vocab_size=$((sentencepiece_vocabsize+3)) --character_coverage=1.0 \ --model_type=$sentencepiece_type --model_prefix=$sentencepiece_model \ --input_sentence_size=10000000 \ - --user_defined_symbols=$(cut -f 2- $train_text | tr " " "\n" | sort | uniq | grep "\[" | tr "\n" "," | sed 's/,$//') + --user_defined_symbols=$(cat $nlsyms | tr "\n" "," | sed 's/,$//') - echo "Tokenizing text for train/valid/test sets..." + echo "$0: tokenizing text for train/valid/test sets..." for dataset in $train_set $test_set; do # validation is included in tests text=data/$dataset/text token_text=data/$dataset/token_text - spm_encode --model=${sentencepiece_model}.model --output_format=piece \ - <(cut -f 2- -d' ' $text) | paste -d" " <(cut -f 1 -d' ' $text) - > $token_text - done - - echo "Preparing text for subword LM..." - mkdir -p $lmdatadir - for dataset in $train_set $test_set; do - token_text=data/$dataset/token_text + cut -f 2- -d" " $text | \ + python3 ../../scripts/spm_encode.py --model=${sentencepiece_model}.model --output_format=piece | \ + paste -d" " <(cut -f 1 -d" " $text) - > $token_text cut -f 2- -d" " $token_text > $lmdatadir/$dataset.tokens done - echo "Preparing extra corpus for subword LM training..." - cat $lmdatadir/fisher_text |\ - spm_encode --model=${sentencepiece_model}.model --output_format=piece |\ + echo "$0: tokenizing extra corpus for subword LM training..." + cat $lmdatadir/fisher_text | \ + python3 ../../scripts/spm_encode.py --model=${sentencepiece_model}.model --output_format=piece | \ cat $lmdatadir/$train_set.tokens - > $lmdatadir/train.tokens - echo "Making a dictionary with swbd+fisher text" - cat $lmdatadir/train.tokens | tr " " "\n" | grep -v -e '^\s*$' | sort | \ + echo "$0: making a dictionary with swbd+fisher text" + cat $lmdatadir/train.tokens | tr " " "\n" | grep -v -e "^\s*$" | sort | \ uniq -c | awk '{print $2,$1}' > $dict wc -l $dict fi @@ -197,7 +191,7 @@ if [ $stage -le 3 ]; then --trainpref $lmdatadir/train.tokens \ --validpref $lmdatadir/$valid_set.tokens \ --testpref $test_paths \ - --destdir $lmdatadir + --destdir $lmdatadir fi [ -z "$free_gpu" ] && [[ $(hostname -f) == *.clsp.jhu.edu ]] && free_gpu=$(free-gpu -n $ngpus) || \ @@ -215,7 +209,7 @@ if [ $stage -le 4 ]; then CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../train.py $lmdatadir --seed 1 \ --task language_modeling_for_asr --dict $lmdict \ --log-interval 500 --log-format simple \ - --num-workers 0 --max-tokens 30720 --max-sentences 1024 \ + --num-workers 0 --max-tokens 25600 --max-sentences 1024 \ --valid-subset $valid_subset --max-sentences-valid 1536 \ --distributed-world-size $ngpus --distributed-rank 0 --distributed-port 100 \ --max-epoch 25 --optimizer adam --lr 0.001 --clip-norm 1.0 \ diff --git a/fairseq/models/speech_lstm.py b/fairseq/models/speech_lstm.py index 0f4440f16..3856838ec 100644 --- a/fairseq/models/speech_lstm.py +++ b/fairseq/models/speech_lstm.py @@ -847,7 +847,7 @@ def conv_lstm_wsj(args): def speech_conv_lstm_librispeech(args): args.dropout = getattr(args, 'dropout', 0.3) args.encoder_rnn_hidden_size = getattr(args, 'encoder_rnn_hidden_size', 1024) - args.encoder_rnn_layers = getattr(args, 'encoder_rnn_layers', 3) + args.encoder_rnn_layers = getattr(args, 'encoder_rnn_layers', 4) args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512) args.decoder_hidden_size = getattr(args, 
'decoder_hidden_size', 1024) args.decoder_layers = getattr(args, 'decoder_layers', 3) diff --git a/fairseq/tasks/language_modeling_for_asr.py b/fairseq/tasks/language_modeling_for_asr.py index 1f94bc0d4..d06f8e0e1 100644 --- a/fairseq/tasks/language_modeling_for_asr.py +++ b/fairseq/tasks/language_modeling_for_asr.py @@ -14,7 +14,7 @@ from .language_modeling import LanguageModelingTask -@register_task('language_modeling_for_asr') +@register_task("language_modeling_for_asr") class LanguageModelingForASRTask(LanguageModelingTask): """ Train a language model. @@ -96,29 +96,31 @@ def setup_task(cls, args, **kwargs): dictionary = None output_dictionary = None if args.data: - paths = args.data.split(':') + paths = args.data.split(":") assert len(paths) > 0 - dict_path = os.path.join(paths[0], 'dict.txt') if args.dict is None \ + dict_path = os.path.join(paths[0], "dict.txt") if args.dict is None \ else args.dict dictionary = TokenDictionary.load(dict_path) - print('| dictionary: {} types'.format(len(dictionary))) + print("| dictionary: {} types".format(len(dictionary))) output_dictionary = dictionary if args.output_dictionary_size >= 0: - output_dictionary = TruncatedDictionary(dictionary, args.output_dictionary_size) + output_dictionary = TruncatedDictionary( + dictionary, args.output_dictionary_size + ) # upgrade old checkpoints - if hasattr(args, 'exclude_self_target'): + if hasattr(args, "exclude_self_target"): args.self_target = not args.exclude_self_target targets = [] - if getattr(args, 'self_target', False): - targets.append('self') - if getattr(args, 'future_target', False): - targets.append('future') - if getattr(args, 'past_target', False): - targets.append('past') + if getattr(args, "self_target", False): + targets.append("self") + if getattr(args, "future_target", False): + targets.append("future") + if getattr(args, "past_target", False): + targets.append("past") if len(targets) == 0: # standard language modeling - targets = ['future'] + targets = ["future"] return cls(args, dictionary, output_dictionary, targets=targets) diff --git a/speech_tools/Makefile b/speech_tools/Makefile index b94b4140a..14a2d21d2 100644 --- a/speech_tools/Makefile +++ b/speech_tools/Makefile @@ -2,18 +2,13 @@ KALDI = .PHONY: all clean -all: kaldi kaldi-io-for-python sentencepiece +all: kaldi kaldi-io-for-python kaldi-io-for-python: rm -rf kaldi-io-for-python git clone https://github.com/vesis84/kaldi-io-for-python.git ln -sf kaldi-io-for-python/kaldi_io/kaldi_io.py kaldi_io.py -sentencepiece: - rm -rf sentencepiece - git clone https://github.com/google/sentencepiece.git - cd sentencepiece && git checkout v0.1.82 && mkdir build && cd build && (cmake3 .. || cmake ..) 
&& $(MAKE) - ifneq ($(strip $(KALDI)),) kaldi: ln -s $(KALDI) kaldi @@ -25,4 +20,4 @@ kaldi: endif clean: - rm -rf kaldi kaldi-io-for-python kaldi_io.py sentencepiece + rm -rf kaldi kaldi-io-for-python kaldi_io.py diff --git a/speech_train.py b/speech_train.py index 156f8da16..c240e8da4 100755 --- a/speech_train.py +++ b/speech_train.py @@ -12,6 +12,7 @@ import math import random +import numpy as np import torch from fairseq import checkpoint_utils, distributed_utils, options, progress_bar, tasks, utils @@ -30,6 +31,7 @@ def main(args, init_distributed=False): # Initialize CUDA and distributed training if torch.cuda.is_available() and not args.cpu: torch.cuda.set_device(args.device_id) + np.random.seed(args.seed) torch.manual_seed(args.seed) if init_distributed: args.distributed_rank = distributed_utils.distributed_init(args) @@ -78,7 +80,6 @@ def main(args, init_distributed=False): lr = trainer.get_lr() train_meter = StopwatchMeter() train_meter.start() - valid_losses = [None] valid_subsets = args.valid_subset.split(',') while ( (lr >= args.min_lr or trainer.get_num_updates() <= getattr(args, 'warmup_updates', 0)) From 1c19fe5e3d6a06b5c74f0349ce72b67e2f5ab452 Mon Sep 17 00:00:00 2001 From: Tongfei Chen Date: Sun, 8 Sep 2019 18:46:20 -0400 Subject: [PATCH 034/119] update tensorized tree implementation --- fairseq/models/external_language_model.py | 21 +- .../tensorized_lookahead_language_model.py | 257 ++++++++++++++++++ speech_recognize.py | 5 +- speech_tools/tensorized_prefix_tree.py | 103 +++++++ speech_tools/utils.py | 8 +- 5 files changed, 385 insertions(+), 9 deletions(-) create mode 100644 fairseq/models/tensorized_lookahead_language_model.py create mode 100644 speech_tools/tensorized_prefix_tree.py diff --git a/fairseq/models/external_language_model.py b/fairseq/models/external_language_model.py index 3d82f24a6..a1d44f6de 100644 --- a/fairseq/models/external_language_model.py +++ b/fairseq/models/external_language_model.py @@ -28,9 +28,18 @@ def clone_state(state): return tuple(map(clone_state, cached_state)) -class LookAheadWordLanguageModel(FairseqLanguageModel): - """A :class:`fairseq.models.FairseqLanguageModel` wrapper for - :class:`_LookAheadWordLanguageModelDecoder`. +class RawOutExternalLanguageModelBase(FairseqLanguageModel): + """Base class for all external language models for ASR whose raw forward output + will be directly used by the caller (rather than, for example, doing normalization + by the caller). + """ + def __init__(self, decoder): + super().__init__(decoder) + + +class LookAheadWordLanguageModel(RawOutExternalLanguageModelBase): + """A :class:`fairseq.models.external_language_model.RawOutExternalLanguageModelBase` + wrapper for :class:`_LookAheadWordLanguageModelDecoder`. """ def __init__(self, wordlm, subword_dict, oov_penalty=1e-4, open_vocab=True): decoder = _LookAheadWordLanguageModelDecoder(wordlm, subword_dict, @@ -266,9 +275,9 @@ def max_positions(self): return int(1e5) # an arbitrary large number -class MultiLevelLanguageModel(FairseqLanguageModel): - """A :class:`fairseq.models.FairseqLanguageModel` wrapper for - :class:`_MultiLevelLanguageModel`. +class MultiLevelLanguageModel(RawOutExternalLanguageModelBase): + """A :class:`fairseq.external_language_model.RawOutExternalLanguageModelBase` + wrapper for :class:`_MultiLevelLanguageModel`. 
""" def __init__(self, wordlm, subwordlm, subwordlm_weight=0.8, oov_penalty=1.0, open_vocab=True): diff --git a/fairseq/models/tensorized_lookahead_language_model.py b/fairseq/models/tensorized_lookahead_language_model.py new file mode 100644 index 000000000..86280c221 --- /dev/null +++ b/fairseq/models/tensorized_lookahead_language_model.py @@ -0,0 +1,257 @@ +# Copyright (c) Tongfei Chen, Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import * +import torch + +from fairseq.models import FairseqLanguageModel, FairseqIncrementalDecoder +from fairseq.models.external_language_model import RawOutExternalLanguageModelBase +from fairseq.data import TokenDictionary +from fairseq import utils + +from speech_tools.tensorized_prefix_tree import TensorizedPrefixTree +from speech_tools.utils import tokenize + + +def _clone_cached_state(cached_state): + if cached_state is None: + return None + + def clone_state(state): + if isinstance(state, list): + return [clone_state(state_i) for state_i in state] + return state.clone() if state is not None else None + + return tuple(map(clone_state, cached_state)) + + +class TensorizedLookaheadLanguageModel(RawOutExternalLanguageModelBase): + """A :class:`fairseq.models.external_language_model.RawOutExternalLanguageModelBase` + wrapper for :class:`_TensorizedLookaheadLanguageModelDecoder`. + """ + def __init__(self, + word_lm: FairseqLanguageModel, + subword_dict: TokenDictionary, + oov_penalty: float = 1e-4, + open_vocab: bool = True + ): + decoder = _TensorizedLookaheadLanguageModelDecoder(word_lm, subword_dict, oov_penalty, open_vocab) + super().__init__(decoder) + + @classmethod + def build_model(cls, args, task): + raise NotImplementedError + + +class _TensorizedLookaheadLanguageModelDecoder(FairseqIncrementalDecoder): + """Look-ahead word language model decoder for end-to-end ASR. It is intended + to be used for beam search decoding. See https://arxiv.org/abs/1808.02608 + for details. We modify the original algorithm a little bit to adapt it to + the case where each tokenized sentence ends with before . 
+ """ + def __init__(self, + word_lm: FairseqLanguageModel, + subword_dict: TokenDictionary, + oov_penalty: float = 1e-4, + open_vocab: bool = True): + super().__init__(word_lm.decoder.dictionary) + + self.lm_decoder: FairseqIncrementalDecoder = word_lm.decoder + assert hasattr(self.lm_decoder, 'masked_copy_incremental_state') and \ + callable(self.lm_decoder.masked_copy_incremental_state), \ + 'The wrapped decoder should implement masked_copy_incremental_state()' + + self.oov_penalty = oov_penalty + self.open_vocab = open_vocab + self.zero = 1e-10 # a sufficiently small value to avoid the log(0) issue + + word_dict: TokenDictionary = self.lm_decoder.dictionary + self.word_pad_idx = word_dict.pad() + self.word_eos_idx = word_dict.eos() + self.word_unk_idx = word_dict.unk() + + self.subword_space_idx = subword_dict.space() + self.subword_pad_idx = subword_dict.pad() + self.subword_eos_idx = subword_dict.eos() + self.subword_vocab_size = len(subword_dict) + + tokenizer: Callable[[str], List[str]] = \ + lambda x: tokenize(x, non_lang_syms=subword_dict.non_lang_syms).split(' ') + self.tree = TensorizedPrefixTree.build(word_dict, subword_dict, tokenizer) + + assert self.tree.max_out_degree() <= self.subword_vocab_size + + @torch.no_grad() + def forward(self, + prev_output_tokens: torch.Tensor, # Z_Tokens[Batch, SeqLength] + encoder_out=None, + incremental_state: Dict[str, Any] = None): + assert incremental_state is not None, 'This model is for incremental decoding only' + prev_output_tokens = prev_output_tokens[:, -1:] # Z_Tokens[Batch, Len=1] + bsz = prev_output_tokens.size(0) + + if prev_output_tokens.device != self.tree.word_idx.device: + self.tree.to_cuda(device=prev_output_tokens.device) + + # Move the batched state to the next state according to the automaton + batch_space_mask = prev_output_tokens.squeeze(-1).eq(self.subword_space_idx) # B[Batch] + cached_state = utils.get_incremental_state(self.lm_decoder, incremental_state, 'cached_state') + + if cached_state is None: # First step + assert (prev_output_tokens == self.subword_eos_idx).all(), \ + 'expecting the input to the first time step to be ' + w: torch.Tensor = prev_output_tokens.new_full([bsz, 1], self.word_eos_idx) # Z[Batch, Len=1] + lm_probs: torch.Tensor = self.lm_decoder.get_normalized_probs( + self.lm_decoder(w, incremental_state=incremental_state), + log_probs=False, sample=None) # R[Batch, 1, Vocab] + cumsum_probs: torch.Tensor = lm_probs.cumsum(dim=-1) # R[Batch, 1, Vocab] + nodes: torch.Tensor = prev_output_tokens.new_full([bsz], self.tree.root_id) # Z_NodeId[Batch] + all_children = self.tree.children[nodes, :] # Z[Batch, PossibleChildren] + + else: # Not the first step + cumsum_probs: torch.Tensor = utils.get_incremental_state( + self, incremental_state, 'cumsum_probs') # R[Batch, 1, Vocab] + nodes: torch.Tensor = utils.get_incremental_state(self, incremental_state, 'nodes') # Z_NodeId[Batch] + assert nodes.size(0) == bsz + w: torch.Tensor = self.tree.word_idx[nodes].unsqueeze(1) # Z[Batch, Len=1] + w[w < 0] = self.word_unk_idx + + old_cached_state = _clone_cached_state(cached_state) + # recompute cumsum_probs from inter-word transition probabilities + # only for those whose prev_output_token is + lm_probs: torch.Tensor = self.lm_decoder.get_normalized_probs( + self.lm_decoder(w, incremental_state=incremental_state), + log_probs=False, sample=None) # R[Batch, 1, Vocab] + self.lm_decoder.masked_copy_incremental_state( + incremental_state, old_cached_state, batch_space_mask) # restore those not masked + 
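+            # For beams that just emitted the word-boundary symbol, the freshly
+            # recomputed word-level cumulative probabilities above replace the cached
+            # ones; every beam then advances in the lexical prefix tree, either to the
+            # child labeled with the emitted subword, to the out-of-tree ("None") node
+            # when no such child exists, or back to the root right after a word boundary.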
cumsum_probs[batch_space_mask] = lm_probs.cumsum(dim=-1)[batch_space_mask] + + prev_all_children = self.tree.children[nodes, :] # Z[Batch, PossibleChildren] + prev_possible_tokens = self.tree.prev_subword_idx[prev_all_children] # Z[Batch, PossibleChildren] + # intra-word transition: go to child; oov transition: go to "None" node + mask = prev_possible_tokens.eq(prev_output_tokens.expand_as(prev_possible_tokens)) + nodes: torch.Tensor = (prev_all_children * mask.long()).sum(dim=1) # Z[Batch] + # inter-word transition: go back to root + nodes[batch_space_mask] = self.tree.root_id # Z[Batch] + all_children = self.tree.children[nodes, :] # Z[Batch, PossibleChildren] + + utils.set_incremental_state(self, incremental_state, 'cumsum_probs', cumsum_probs) + utils.set_incremental_state(self, incremental_state, 'nodes', nodes) + + # Compute probabilities + # initialize out_probs [Batch, 1, Vocab] + if self.open_vocab: + # set out_probs to oov_penalty * P(|h) (case 3 in Eqn. 15) + out_probs = self.oov_penalty * ( + cumsum_probs[:, :, self.word_unk_idx] - \ + cumsum_probs[:, :, self.word_unk_idx - 1] + ).unsqueeze(-1).repeat(1, 1, self.subword_vocab_size) + + # set the probability of emitting to 0 if prev_output_tokens + # is or , and that of emitting to 0 if + # prev_output_tokens is not + batch_space_eos_mask = batch_space_mask | \ + prev_output_tokens.squeeze(-1).eq(self.subword_eos_idx) + out_probs[batch_space_eos_mask, :, self.subword_space_idx] = self.zero + out_probs[~batch_space_mask, :, self.subword_eos_idx] = self.zero + + # set transition probability to 1 for those whose node is out of the + # tree, i.e. node is None (case 4 in Eqn. 15) + batch_node_none_mask = nodes.eq(self.tree.none_id) # B[Batch] + out_probs[batch_node_none_mask] = 1. + else: + # set out_probs to 0 + out_probs = cumsum_probs.new_full([bsz, 1, self.subword_vocab_size], self.zero) + + # compute parent probabilities for those whose node is not None + left_ranges = self.tree.word_set_idx[nodes, 0] # Z[Batch] + right_ranges = self.tree.word_set_idx[nodes, 1] # Z[Batch] + batch_node_not_root_mask = nodes.ne(self.tree.none_id) & nodes.ne(self.tree.root_id) # B[Batch] + sum_probs = torch.where( + batch_node_not_root_mask, + (cumsum_probs.squeeze(1).gather(-1, right_ranges.unsqueeze(-1)) - + cumsum_probs.squeeze(1).gather(-1, left_ranges.unsqueeze(-1))).squeeze(-1), + cumsum_probs.new([1.0]) + ) # R[Batch] + + # compute transition probabilities to child nodes (case 2 in Eqn. 15) + left_ranges_of_all_children = self.tree.word_set_idx[all_children, 0] # Z[Batch, PossibleChildren] + right_ranges_of_all_children = self.tree.word_set_idx[all_children, 1] # Z[Batch, PossibleChildren] + cumsum_probs_of_all_children = ( + cumsum_probs.squeeze(1).gather(-1, right_ranges_of_all_children) - + cumsum_probs.squeeze(1).gather(-1, left_ranges_of_all_children) + ).unsqueeze(1) / sum_probs.unsqueeze(-1).unsqueeze(-1) # R[Batch, 1, PossibleChildren] + cumsum_probs_of_all_children[sum_probs < self.zero, :, :] = self.zero + next_possible_tokens = self.tree.prev_subword_idx[all_children] # Z[Batch, PossibleChildren] + out_probs.scatter_( + -1, + next_possible_tokens.unsqueeze(1), + cumsum_probs_of_all_children + ) + # assume self.subword_pad_idx is the padding index in self.tree.prev_subword_idx + out_probs[:, :, self.subword_pad_idx] = self.zero + + # apply word-level probabilities for (case 1 in Eqn. 
15) + word_idx = self.tree.word_idx[nodes] # Z[Batch] + batch_node_word_end_mask = word_idx.ge(0) # B[Batch] + # get rid of -1's (word idx of root or non-terminal states). It doesn't + # matter what the "dummy" index it would be replaced with (as it will + # always be masked out by batch_node_word_end_mask), as long as it is > 0 + word_idx[word_idx < 0] = 1 + word_probs = torch.where( + sum_probs < self.zero, + cumsum_probs.new([self.zero]), + ( + cumsum_probs.squeeze(1).gather(-1, word_idx.unsqueeze(-1)) - \ + cumsum_probs.squeeze(1).gather(-1, word_idx.unsqueeze(-1) - 1) + ).squeeze(-1) / sum_probs + ) # R[Batch] + out_probs[batch_node_word_end_mask, 0, self.subword_space_idx] = \ + word_probs[batch_node_word_end_mask] + + # take log of probs and clip it from below to avoid log(0) + out_logprobs = torch.log(torch.max(out_probs, out_probs.new([self.zero]))) + + # assign log-probs of emitting word to that of emitting subword + out_logprobs[batch_space_mask, :, self.subword_eos_idx] = \ + torch.log(lm_probs)[batch_space_mask, :, self.word_eos_idx] + + utils.set_incremental_state(self, incremental_state, 'out_logprobs', out_logprobs) + + # note that here we return log-probs rather than logits, and the second + # element is None, which is usually a tensor of attention weights in + # attention-based models + return out_logprobs, None + + def reorder_incremental_state(self, incremental_state, new_order): + super().reorder_incremental_state(incremental_state, new_order) + + cumsum_probs = utils.get_incremental_state( + self, incremental_state, 'cumsum_probs') + if cumsum_probs is not None: + new_cumsum_probs = cumsum_probs.index_select(0, new_order) + utils.set_incremental_state(self, incremental_state, 'cumsum_probs', + new_cumsum_probs) + + nodes = utils.get_incremental_state(self, incremental_state, 'nodes') + if nodes is not None: + new_nodes = nodes.index_select(0, new_order) + utils.set_incremental_state(self, incremental_state, 'nodes', + new_nodes) + + def get_normalized_probs(self, net_output, log_probs, sample=None): + """Get normalized probabilities (or log probs) from a net's output.""" + # in-place op as not being used for backprop + return net_output[0] if log_probs else net_output[0].exp_() + + def max_positions(self): + return int(1e5) # an arbitrary large number + + def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs): + pass + + def output_layer(self, features, **kwargs): + pass + diff --git a/speech_recognize.py b/speech_recognize.py index 79877cc05..93f378f58 100755 --- a/speech_recognize.py +++ b/speech_recognize.py @@ -15,7 +15,8 @@ from fairseq import wer, checkpoint_utils, options, progress_bar, tasks, utils from fairseq.meters import StopwatchMeter, TimeMeter from fairseq.models import FairseqLanguageModel -from fairseq.models.external_language_model import LookAheadWordLanguageModel, MultiLevelLanguageModel +from fairseq.models.external_language_model import MultiLevelLanguageModel +from fairseq.models.tensorized_lookahead_language_model import TensorizedLookaheadLanguageModel from fairseq.utils import import_user_module from speech_tools.utils import plot_attention @@ -58,7 +59,7 @@ def main(args): del models[i] print('| LM fusion with Multi-level LM') else: - models[i] = LookAheadWordLanguageModel(m, dict, + models[i] = TensorizedLookaheadLanguageModel(m, dict, oov_penalty=args.oov_penalty, open_vocab=not args.disable_open_vocab) print('| LM fusion with Look-ahead Word LM') diff --git 
a/speech_tools/tensorized_prefix_tree.py b/speech_tools/tensorized_prefix_tree.py new file mode 100644 index 000000000..33d6f3728 --- /dev/null +++ b/speech_tools/tensorized_prefix_tree.py @@ -0,0 +1,103 @@ +# Copyright (c) Tongfei Chen, Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os, re +import numpy as np +import torch + +from typing import * +from fairseq.data import TokenDictionary +from speech_tools.utils import lexical_prefix_tree + + +class TensorizedPrefixTree: + """ + A tensorized lexical prefix tree designed for maximum parallelism in ASR decoding. + """ + + def __init__(self, + children: np.ndarray, # NodeId[NodeId, NumChildren] + prev_subword_idx: np.ndarray, # SubWordId[NodeId] + word_idx: np.ndarray, # WordId[NodeId]; -1 means non-terminal node + word_set_idx: np.ndarray # WordId[NodeId, 2 = (first-1, last)] + ): + self.children = children + self.prev_subword_idx = prev_subword_idx + self.word_idx = word_idx + self.word_set_idx = word_set_idx + + self.none_id = 0 + self.root_id = 1 + + def max_out_degree(self) -> int: + return self.children.shape[1] + + def to_cuda(self, device): + self.children = self.children.to(device=device) + self.prev_subword_idx = self.prev_subword_idx.to(device=device) + self.word_idx = self.word_idx.to(device=device) + self.word_set_idx = self.word_set_idx.to(device=device) + + @staticmethod + def build( + word_dict: TokenDictionary, + subword_dict: TokenDictionary, + subword_tokenizer: Callable[[str], List[str]] = None + ): + """ + Builds a tensorized lexical prefix tree for words. + """ + + root = lexical_prefix_tree( + word_dict=word_dict, + subword_dict=subword_dict, + subword_tokenizer=subword_tokenizer + ) # build traditional tree data structure by reusing existing routines + + # Performs pre-order traversal of this tree to assign an index for each node + max_num_children = 0 + nodes = [None] # nodes[0] is a dummy node for OOV + node_to_id_dict = {} + stack = [root] + + while len(stack) > 0: + curr = stack.pop() + node_id = len(nodes) + nodes.append(curr) + node_to_id_dict[curr] = node_id + if len(curr.children) > max_num_children: + max_num_children = len(curr.children) + + # Guarantee that the children are traversed ascendingly according to the subword index + for _, next_node in sorted(curr.children.items(), key=lambda t: t[0], reverse=True): + stack.append(next_node) + + # Construct the tree + num_nodes = len(nodes) + children = np.full([num_nodes, max_num_children], 0, dtype=np.int64) + prev_subword_idx = np.full([num_nodes], subword_dict.pad(), dtype=np.int64) + word_idx = np.full([num_nodes], -1, dtype=np.int64) + word_set_idx = np.full([num_nodes, 2], word_dict.pad(), dtype=np.int64) + + for node_id in range(1, len(nodes)): # skip 0, which is `None` + node = nodes[node_id] + # Guarantee that the children are traversed ascendingly according to the subword index + for i, (subword_id, child) in enumerate(sorted(node.children.items(), key=lambda t: t[0])): + child_node_id = node_to_id_dict[child] + children[node_id, i] = child_node_id + prev_subword_idx[child_node_id] = subword_id + + word_idx[node_id] = node.word_idx + if node.word_set is not None: + word_set_idx[node_id] = node.word_set + else: + word_set_idx[node_id] = [0, len(word_dict) - 1] + + return TensorizedPrefixTree( + children=torch.from_numpy(children), + prev_subword_idx=torch.from_numpy(prev_subword_idx), + word_idx=torch.from_numpy(word_idx), + 
word_set_idx=torch.from_numpy(word_set_idx) + ) diff --git a/speech_tools/utils.py b/speech_tools/utils.py index 1db622fec..d5a4bad74 100644 --- a/speech_tools/utils.py +++ b/speech_tools/utils.py @@ -6,10 +6,12 @@ import os, re import numpy as np from collections import Counter +from typing import Callable, List import torch from fairseq import utils +from fairseq.data import TokenDictionary def tokenize(sent, space='', non_lang_syms=None): @@ -257,7 +259,11 @@ def aligned_print(ref, hyp, steps): return out_str -def lexical_prefix_tree(word_dict, subword_dict, subword_tokenizer=None): +def lexical_prefix_tree( + word_dict: TokenDictionary, + subword_dict: TokenDictionary, + subword_tokenizer: Callable[[str], List[str]] = None +): """Build a lexical prefix tree for words. Args: From f21c428686802046b810127fff45963dfcc81924 Mon Sep 17 00:00:00 2001 From: freewym Date: Tue, 10 Sep 2019 23:26:00 -0400 Subject: [PATCH 035/119] switch to pip install kaldi_io --- fairseq/data/scp_dataset.py | 5 ++++- speech_tools/.gitignore | 3 --- speech_tools/Makefile | 9 ++------- speech_tools/text2token.py | 6 +++--- tests/test_speech_dataset.py | 5 ++++- 5 files changed, 13 insertions(+), 15 deletions(-) diff --git a/fairseq/data/scp_dataset.py b/fairseq/data/scp_dataset.py index 2f9836e5c..02a07a3b1 100644 --- a/fairseq/data/scp_dataset.py +++ b/fairseq/data/scp_dataset.py @@ -8,7 +8,10 @@ import numpy as np import torch -import speech_tools.kaldi_io as kaldi_io +try: + import kaldi_io +except ImportError: + raise ImportError('Please install kaldi_io with: pip install kaldi_io') class ScpDataset(torch.utils.data.Dataset): diff --git a/speech_tools/.gitignore b/speech_tools/.gitignore index b86428794..6fc07e5f7 100644 --- a/speech_tools/.gitignore +++ b/speech_tools/.gitignore @@ -1,4 +1 @@ kaldi -kaldi-io-for-python -kaldi_io.py -sentencepiece diff --git a/speech_tools/Makefile b/speech_tools/Makefile index 14a2d21d2..e5fec0150 100644 --- a/speech_tools/Makefile +++ b/speech_tools/Makefile @@ -2,12 +2,7 @@ KALDI = .PHONY: all clean -all: kaldi kaldi-io-for-python - -kaldi-io-for-python: - rm -rf kaldi-io-for-python - git clone https://github.com/vesis84/kaldi-io-for-python.git - ln -sf kaldi-io-for-python/kaldi_io/kaldi_io.py kaldi_io.py +all: kaldi ifneq ($(strip $(KALDI)),) kaldi: @@ -20,4 +15,4 @@ kaldi: endif clean: - rm -rf kaldi kaldi-io-for-python kaldi_io.py + rm -rf kaldi diff --git a/speech_tools/text2token.py b/speech_tools/text2token.py index 9db179a08..e020ba26f 100755 --- a/speech_tools/text2token.py +++ b/speech_tools/text2token.py @@ -18,7 +18,7 @@ def get_parser(): help='skip first n columns') parser.add_argument('--space', default='', type=str, help='space symbol') - parser.add_argument('--endswithspace', default=True, type=bool, + parser.add_argument('--ends-with-space', default=True, type=bool, help='Whether to append to the end of each ' 'tokenized sentence.') parser.add_argument('--non-lang-syms', default=None, type=str, @@ -42,12 +42,12 @@ def main(args): tokenized = tokenize(' '.join(entry[args.skip_ncols:]), space=args.space, non_lang_syms=nls) if args.skip_ncols > 0: - if args.endswithspace: + if args.ends_with_space: print(' '.join(entry[:args.skip_ncols]) + ' ' + tokenized + ' ' + args.space) else: print(' '.join(entry[:args.skip_ncols]) + ' ' + tokenized) else: - if args.endswithspace: + if args.ends_with_space: print(tokenized + ' ' + args.space) else: print(tokenized) diff --git a/tests/test_speech_dataset.py b/tests/test_speech_dataset.py index 062c83891..33f4a8529 100644 
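The `TensorizedPrefixTree` introduced above flattens the lexical prefix tree into plain integer tensors (`children`, `prev_subword_idx`, `word_idx`, `word_set_idx`) so that the look-ahead LM can advance every hypothesis in a batch with tensor indexing alone, instead of per-hypothesis pointer chasing. Below is a minimal, self-contained sketch of the two core lookups — following a child edge after a subword is emitted, and turning a node's word-index range into a probability mass via cumulative word probabilities (case 2 of Eqn. 15). The toy lexicon, ids and probabilities are invented for illustration; this is not Espresso's actual class.

``` python
# Minimal sketch (not Espresso code): toy lexicon with word ids 1..3 = "a", "ab", "b"
# and subword ids a=0, b=1, pad=3, flattened into tensors as in TensorizedPrefixTree.
import torch

NONE, ROOT = 0, 1
children = torch.tensor([[0, 0],   # 0: "None" node (reached on OOV)
                         [2, 3],   # 1: root -> node 2 via 'a', node 3 via 'b'
                         [4, 0],   # 2: "a"  -> node 4 via 'b'
                         [0, 0],   # 3: "b"  (leaf)
                         [0, 0]])  # 4: "ab" (leaf)
prev_subword_idx = torch.tensor([3, 3, 0, 1, 1])  # subword labeling the incoming edge
word_set_idx = torch.tensor([[0, 0],              # (first-1, last) word-id range per node
                             [0, 3], [0, 2], [2, 3], [1, 2]])

# intra-word transition for a whole batch: follow the child whose edge label equals
# the subword just emitted (result 0, the "None" node, if no child matches)
nodes = torch.tensor([ROOT, 2])        # current node of each hypothesis
prev_tokens = torch.tensor([0, 1])     # subword just emitted ('a', 'b')
cand = children[nodes]                                       # [B, max_children]
match = prev_subword_idx[cand].eq(prev_tokens.unsqueeze(1))  # which child matches
nodes = (cand * match.long()).sum(dim=1)
print(nodes.tolist())  # [2, 4]

# look-ahead mass of a node's subtree from cumulative word-level probabilities:
# P(subtree) = cumsum[last] - cumsum[first-1]
word_probs = torch.tensor([0.0, 0.5, 0.2, 0.3])  # dummy unigram over word ids 0..3
cumsum = word_probs.cumsum(dim=-1)
mass = cumsum[word_set_idx[nodes, 1]] - cumsum[word_set_idx[nodes, 0]]
print(mass.tolist())  # ~[0.7, 0.2] -> P('b' | node "a") = 0.2 / 0.7 (case 2 of Eqn. 15)
```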
--- a/tests/test_speech_dataset.py +++ b/tests/test_speech_dataset.py @@ -14,7 +14,10 @@ SpeechDataset, TokenDictionary, TokenTextDataset, ScpCachedDataset, ScpInMemoryDataset) -import speech_tools.kaldi_io as kaldi_io +try: + import kaldi_io +except ImportError: + raise ImportError('Please install kaldi_io with: pip install kaldi_io') class TestSpeechDataset(unittest.TestCase): From bac7415f45403caac12d880c00bf5fd9aef2a2b2 Mon Sep 17 00:00:00 2001 From: Tongfei Chen Date: Sun, 15 Sep 2019 22:39:46 -0400 Subject: [PATCH 036/119] Update README.md; add logo; slightly change LM weight and beam size for Librispeech; code adaptation/changes according to the commits on Sep 17, 2019 --- LICENSE | 4 +- README.md | 219 ++++--------------- README_fairseq.md | 216 ++++++++++++++++++ espresso_logo.png | Bin 0 -> 13344 bytes examples/asr_librispeech/run.sh | 4 +- fairseq/criterions/cross_entropy_with_wer.py | 1 + 6 files changed, 262 insertions(+), 182 deletions(-) create mode 100644 README_fairseq.md create mode 100644 espresso_logo.png diff --git a/LICENSE b/LICENSE index b96dcb048..13a6f665f 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,8 @@ MIT License -Copyright (c) Facebook, Inc. and its affiliates. +Copyright for the original fairseq code are held by Facebook, Inc. and its +affiliates as part of project Espresso. All other copyright for project Espresso +are held by Espresso authors. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index cc1c76ec3..87dc730a6 100644 --- a/README.md +++ b/README.md @@ -1,143 +1,49 @@ -

- [HTML badge block: MIT License, Latest Release, Build Status, Documentation Status]

- --------------------------------------------------------------------------------- - -Fairseq(-py) is a sequence modeling toolkit that allows researchers and -developers to train custom models for translation, summarization, language -modeling and other text generation tasks. - -We provide reference implementations of various sequence modeling papers: - -
List of implemented papers

- -* **Convolutional Neural Networks (CNN)** - + [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md) - + [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md) - + [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel) - + [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md) - + [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md) -* **LightConv and DynamicConv models** - + [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md) -* **Long Short-Term Memory (LSTM) networks** - + Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015) -* **Transformer (self-attention) networks** - + Attention Is All You Need (Vaswani et al., 2017) - + [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md) - + [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md) - + [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md) - + [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md) - + [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context (Dai et al., 2019)](examples/truncated_bptt/README.md) - + [Adaptive Attention Span in Transformers (Sukhbaatar et al., 2019)](examples/adaptive_span/README.md) - + [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md) - + [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) - + [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) - + [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md ) - + [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md) - + [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md) - + [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md) - + [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md) - + [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md) - + [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md) - + [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md) - + [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md) -* **Non-autoregressive Transformers** - + Non-Autoregressive Neural Machine Translation (Gu et al., 2017) - + Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018) - + Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al. 
2019) - + Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019) - + [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md) -* **Finetuning** - + [Better Fine-Tuning by Reducing Representational Collapse (Aghajanyan et al. 2020)](examples/rxf/README.md) - -

+ -### What's New: +# Espresso -* December 2020: [GottBERT model and code released](examples/gottbert/README.md) -* November 2020: Adopted the [Hydra](https://github.com/facebookresearch/hydra) configuration framework - * [see documentation explaining how to use it for new and existing projects](docs/hydra_integration.md) -* November 2020: [fairseq 0.10.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.10.0) -* October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md) -* October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md) -* October 2020: [Added CRISS models and code](examples/criss/README.md) -* September 2020: [Added Linformer code](examples/linformer/README.md) -* September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md) -* August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md) -* August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md) -* July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md) - -
Previous updates

- -* May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq) -* April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md) -* April 2020: [Quant-Noise code released](examples/quant_noise/README.md) -* April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md) -* March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md) -* February 2020: [mBART model and code released](examples/mbart/README.md) -* February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/master/examples/backtranslation#training-your-own-model-wmt18-english-german) -* December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0) -* November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example) -* November 2019: [CamemBERT model and code released](examples/camembert/README.md) -* November 2019: [BART model and code released](examples/bart/README.md) -* November 2019: [XLM-R models and code released](examples/xlmr/README.md) -* September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md) -* August 2019: [WMT'19 models released](examples/wmt19/README.md) -* July 2019: fairseq relicensed under MIT license -* July 2019: [RoBERTa models and code released](examples/roberta/README.md) -* June 2019: [wav2vec models and code released](examples/wav2vec/README.md) - -

- -### Features: - -* multi-GPU training on one machine or across multiple machines (data and model parallel) -* fast generation on both CPU and GPU with multiple search algorithms implemented: - + beam search - + Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424)) - + sampling (unconstrained, top-k and top-p/nucleus) - + [lexically constrained decoding](examples/constrained_decoding/README.md) (Post & Vilar, 2018) -* [gradient accumulation](https://fairseq.readthedocs.io/en/latest/getting_started.html#large-mini-batch-training-with-delayed-updates) enables training with large mini-batches even on a single GPU -* [mixed precision training](https://fairseq.readthedocs.io/en/latest/getting_started.html#training-with-half-precision-floating-point-fp16) (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores)) -* [extensible](https://fairseq.readthedocs.io/en/latest/overview.html): easily register new models, criterions, tasks, optimizers and learning rate schedulers -* [flexible configuration](docs/hydra_integration.md) based on [Hydra](https://github.com/facebookresearch/hydra) allowing a combination of code, command-line and file based configuration - -We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples) -with a convenient `torch.hub` interface: - -``` python -en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model') -en2de.translate('Hello world', beam=5) -# 'Hallo Welt' -``` +Espresso is an open-source, modular, extensible end-to-end neural automatic speech recognition (ASR) toolkit based on the deep learning library [PyTorch](https://github.com/pytorch/pytorch) and the popular neural machine translation toolkit [`fairseq`](https://github.com/pytorch/fairseq). Espresso supports distributed training across GPUs and computing nodes, and features various decoding approaches commonly employed in ASR, including look-ahead word-based language model fusion, for which a fast, parallelized decoder is implemented. + +We provide state-of-the-art training recipes for the following speech datasets: + * [WSJ](https://github.com/freewym/espresso/tree/master/examples/asr_wsj) + * [LibriSpeech](https://github.com/freewym/espresso/tree/master/examples/asr_librispeech) + * [Switchboard](https://github.com/freewym/espresso/tree/master/examples/asr_swbd) + +### What's New: -See the PyTorch Hub tutorials for [translation](https://pytorch.org/hub/pytorch_fairseq_translation/) -and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more examples. +* June 2020: Transformer recipes released. +* April 2020: Both [E2E LF-MMI](https://www.isca-speech.org/archive/Interspeech_2018/pdfs/1423.pdf) (using [PyChain](https://github.com/YiwenShaoStephen/pychain)) and Cross-Entropy training for hybrid ASR are now supported. WSJ recipes are provided [here](https://github.com/freewym/espresso/tree/master/examples/asr_wsj/run_chain_e2e.sh) and [here](https://github.com/freewym/espresso/tree/master/examples/asr_wsj/run_xent.sh) as examples, respectively. +* March 2020: SpecAugment is supported and relevant recipes are released. +* September 2019: We are in an effort of isolating Espresso from fairseq, resulting in a standalone package that can be directly `pip install`ed. 
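The overview above notes that Espresso's decoder supports look-ahead word-based language model fusion. At its core this is shallow fusion: at each decoding step the ASR model's log-probabilities are interpolated with the external LM's log-probabilities under a tunable weight (the `--lm-weight` option in the recipes; extra terms such as `--coverage-weight` and `--eos-factor` are omitted here). The snippet below is only a hedged sketch of that combination, not the decoder's actual code; the shapes and the weight value are illustrative.

``` python
# Hedged sketch of shallow-fusion scoring (not the actual Espresso decoder code).
import torch

def fuse_step_scores(asr_lprobs, lm_lprobs, lm_weight=0.47):
    """score = log P_asr(y_t | ...) + lm_weight * log P_lm(y_t | ...)"""
    return asr_lprobs + lm_weight * lm_lprobs

asr_lprobs = torch.log_softmax(torch.randn(2, 1, 8), dim=-1)  # [batch, 1, subword vocab]
lm_lprobs = torch.log_softmax(torch.randn(2, 1, 8), dim=-1)   # e.g. from the look-ahead LM
print(fuse_step_scores(asr_lprobs, lm_lprobs).shape)          # torch.Size([2, 1, 8])
```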
# Requirements and Installation * [PyTorch](http://pytorch.org/) version >= 1.5.0 * Python version >= 3.6 * For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl) -* **To install fairseq** and develop locally: +* **To install Espresso** from source and develop locally: ``` bash -git clone https://github.com/pytorch/fairseq -cd fairseq -pip install --editable ./ +git clone https://github.com/freewym/espresso +cd espresso +pip install --editable . # on MacOS: # CFLAGS="-stdlib=libc++" pip install --editable ./ +pip install kaldi_io +pip install sentencepiece +cd espresso/tools; make KALDI= +``` + +add your Python path to `PATH` variable in `examples/asr_/path.sh`, the current default is `~/anaconda3/bin`. -# to install the latest stable release (0.10.0) -# pip install fairseq==0.10.0 +kaldi\_io is required for reading kaldi scp files. sentencepiece is required for subword pieces training/encoding. +Kaldi is required for data preparation, feature extraction, scoring for some datasets (e.g., Switchboard), and decoding for all hybrid systems. +* If you want to use [PyChain](https://github.com/YiwenShaoStephen/pychain) for [LF-MMI](https://www.isca-speech.org/archive/Interspeech_2016/pdfs/0595.PDF) training, you also need to install PyChain (and OpenFst): + +edit `PYTHON_DIR` variable in `espresso/tools/Makefile` (default: `~/anaconda3/bin`), and then +```bash +cd espresso/tools; make openfst pychain ``` * **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library: @@ -150,67 +56,22 @@ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cud --global-option="--fast_multihead_attn" ./ ``` -* **For large datasets** install [PyArrow](https://arrow.apache.org/docs/python/install.html#using-pip): `pip install pyarrow` -* If you use Docker make sure to increase the shared memory size either with `--ipc=host` or `--shm-size` - as command line options to `nvidia-docker run` . - -# Getting Started - -The [full documentation](https://fairseq.readthedocs.io/) contains instructions -for getting started, training new models and extending fairseq with new model -types and tasks. - -# Pre-trained models and examples - -We provide pre-trained models and pre-processed, binarized test sets for several tasks listed below, -as well as example training and evaluation commands. 
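After the installation steps above, a quick way to confirm that the pip-installed `kaldi_io` and `sentencepiece` packages work is to read one feature matrix from a Kaldi scp file and encode a sentence into subword pieces. This is only an illustrative check under assumed file names (`feats.scp` and `spm.model` are placeholders); the scp-reading part mirrors how `ScpDataset.read_scp` calls `kaldi_io.read_mat` on each extended filename.

``` python
# Illustrative smoke test; 'feats.scp' and 'spm.model' are hypothetical paths.
import kaldi_io
import sentencepiece as spm

# each scp line is "<utt_id> <extended_filename>", e.g. "utt1 feats.ark:14"
with open('feats.scp', 'r', encoding='utf-8') as f:
    utt_id, rxfilename = f.readline().strip().split(None, 1)
feat = kaldi_io.read_mat(rxfilename)     # numpy array of shape [num_frames, feat_dim]
print(utt_id, feat.shape)

sp = spm.SentencePieceProcessor()
sp.Load('spm.model')                     # a sentencepiece model trained by the recipe
print(sp.EncodeAsPieces('HELLO WORLD'))  # subword pieces for a sample transcript
```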
- -* [Translation](examples/translation/README.md): convolutional and transformer models are available -* [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available - -We also have more detailed READMEs to reproduce results from specific papers: - -* [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md) -* [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md) -* [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md) -* [Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)](examples/quant_noise/README.md) -* [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md) -* [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md) -* [Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)](examples/layerdrop/README.md) -* [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md) -* [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md) -* [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) -* [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) -* [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md) -* [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md) -* [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md) -* [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md) -* [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel) -* [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md) -* [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md) -* [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md) -* [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/README.conv.md) - -# Join the fairseq community - -* Twitter: https://twitter.com/fairseq -* Facebook page: https://www.facebook.com/groups/fairseq.users -* Google group: https://groups.google.com/forum/#!forum/fairseq-users - # License -fairseq(-py) is MIT-licensed. -The license applies to the pre-trained models as well. +Espresso is MIT-licensed. 
# Citation -Please cite as: +Please cite Espresso as: ``` bibtex -@inproceedings{ott2019fairseq, - title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling}, - author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli}, - booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations}, +@inproceedings{wang2019espresso, + title = {Espresso: A Fast End-to-end Neural Speech Recognition Toolkit}, + author = {Yiming Wang and Tongfei Chen and Hainan Xu + and Shuoyang Ding and Hang Lv and Yiwen Shao + and Nanyun Peng and Lei Xie and Shinji Watanabe + and Sanjeev Khudanpur}, + booktitle = {2019 IEEE Automatic Speech Recognition and Understanding Workshop (ASRU)}, year = {2019}, } ``` diff --git a/README_fairseq.md b/README_fairseq.md new file mode 100644 index 000000000..cc1c76ec3 --- /dev/null +++ b/README_fairseq.md @@ -0,0 +1,216 @@ +

+ [HTML badge block: MIT License, Latest Release, Build Status, Documentation Status]

+ +-------------------------------------------------------------------------------- + +Fairseq(-py) is a sequence modeling toolkit that allows researchers and +developers to train custom models for translation, summarization, language +modeling and other text generation tasks. + +We provide reference implementations of various sequence modeling papers: + +
List of implemented papers

+ +* **Convolutional Neural Networks (CNN)** + + [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md) + + [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md) + + [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel) + + [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md) + + [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md) +* **LightConv and DynamicConv models** + + [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md) +* **Long Short-Term Memory (LSTM) networks** + + Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015) +* **Transformer (self-attention) networks** + + Attention Is All You Need (Vaswani et al., 2017) + + [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md) + + [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md) + + [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md) + + [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md) + + [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context (Dai et al., 2019)](examples/truncated_bptt/README.md) + + [Adaptive Attention Span in Transformers (Sukhbaatar et al., 2019)](examples/adaptive_span/README.md) + + [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md) + + [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) + + [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) + + [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md ) + + [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md) + + [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md) + + [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md) + + [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md) + + [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md) + + [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md) + + [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md) + + [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md) +* **Non-autoregressive Transformers** + + Non-Autoregressive Neural Machine Translation (Gu et al., 2017) + + Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018) + + Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al. 
2019) + + Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019) + + [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md) +* **Finetuning** + + [Better Fine-Tuning by Reducing Representational Collapse (Aghajanyan et al. 2020)](examples/rxf/README.md) + +

+ +### What's New: + +* December 2020: [GottBERT model and code released](examples/gottbert/README.md) +* November 2020: Adopted the [Hydra](https://github.com/facebookresearch/hydra) configuration framework + * [see documentation explaining how to use it for new and existing projects](docs/hydra_integration.md) +* November 2020: [fairseq 0.10.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.10.0) +* October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md) +* October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md) +* October 2020: [Added CRISS models and code](examples/criss/README.md) +* September 2020: [Added Linformer code](examples/linformer/README.md) +* September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md) +* August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md) +* August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md) +* July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md) + +
Previous updates

+ +* May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq) +* April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md) +* April 2020: [Quant-Noise code released](examples/quant_noise/README.md) +* April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md) +* March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md) +* February 2020: [mBART model and code released](examples/mbart/README.md) +* February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/master/examples/backtranslation#training-your-own-model-wmt18-english-german) +* December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0) +* November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example) +* November 2019: [CamemBERT model and code released](examples/camembert/README.md) +* November 2019: [BART model and code released](examples/bart/README.md) +* November 2019: [XLM-R models and code released](examples/xlmr/README.md) +* September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md) +* August 2019: [WMT'19 models released](examples/wmt19/README.md) +* July 2019: fairseq relicensed under MIT license +* July 2019: [RoBERTa models and code released](examples/roberta/README.md) +* June 2019: [wav2vec models and code released](examples/wav2vec/README.md) + +

+ +### Features: + +* multi-GPU training on one machine or across multiple machines (data and model parallel) +* fast generation on both CPU and GPU with multiple search algorithms implemented: + + beam search + + Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424)) + + sampling (unconstrained, top-k and top-p/nucleus) + + [lexically constrained decoding](examples/constrained_decoding/README.md) (Post & Vilar, 2018) +* [gradient accumulation](https://fairseq.readthedocs.io/en/latest/getting_started.html#large-mini-batch-training-with-delayed-updates) enables training with large mini-batches even on a single GPU +* [mixed precision training](https://fairseq.readthedocs.io/en/latest/getting_started.html#training-with-half-precision-floating-point-fp16) (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores)) +* [extensible](https://fairseq.readthedocs.io/en/latest/overview.html): easily register new models, criterions, tasks, optimizers and learning rate schedulers +* [flexible configuration](docs/hydra_integration.md) based on [Hydra](https://github.com/facebookresearch/hydra) allowing a combination of code, command-line and file based configuration + +We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples) +with a convenient `torch.hub` interface: + +``` python +en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model') +en2de.translate('Hello world', beam=5) +# 'Hallo Welt' +``` + +See the PyTorch Hub tutorials for [translation](https://pytorch.org/hub/pytorch_fairseq_translation/) +and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more examples. + +# Requirements and Installation + +* [PyTorch](http://pytorch.org/) version >= 1.5.0 +* Python version >= 3.6 +* For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl) +* **To install fairseq** and develop locally: + +``` bash +git clone https://github.com/pytorch/fairseq +cd fairseq +pip install --editable ./ + +# on MacOS: +# CFLAGS="-stdlib=libc++" pip install --editable ./ + +# to install the latest stable release (0.10.0) +# pip install fairseq==0.10.0 +``` + +* **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library: + +``` bash +git clone https://github.com/NVIDIA/apex +cd apex +pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \ + --global-option="--deprecated_fused_adam" --global-option="--xentropy" \ + --global-option="--fast_multihead_attn" ./ +``` + +* **For large datasets** install [PyArrow](https://arrow.apache.org/docs/python/install.html#using-pip): `pip install pyarrow` +* If you use Docker make sure to increase the shared memory size either with `--ipc=host` or `--shm-size` + as command line options to `nvidia-docker run` . + +# Getting Started + +The [full documentation](https://fairseq.readthedocs.io/) contains instructions +for getting started, training new models and extending fairseq with new model +types and tasks. + +# Pre-trained models and examples + +We provide pre-trained models and pre-processed, binarized test sets for several tasks listed below, +as well as example training and evaluation commands. 
+ +* [Translation](examples/translation/README.md): convolutional and transformer models are available +* [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available + +We also have more detailed READMEs to reproduce results from specific papers: + +* [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md) +* [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md) +* [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md) +* [Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)](examples/quant_noise/README.md) +* [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md) +* [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md) +* [Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)](examples/layerdrop/README.md) +* [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md) +* [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md) +* [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) +* [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) +* [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md) +* [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md) +* [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md) +* [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md) +* [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel) +* [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md) +* [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md) +* [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md) +* [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/README.conv.md) + +# Join the fairseq community + +* Twitter: https://twitter.com/fairseq +* Facebook page: https://www.facebook.com/groups/fairseq.users +* Google group: https://groups.google.com/forum/#!forum/fairseq-users + +# License + +fairseq(-py) is MIT-licensed. +The license applies to the pre-trained models as well. 
+ +# Citation + +Please cite as: + +``` bibtex +@inproceedings{ott2019fairseq, + title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling}, + author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli}, + booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations}, + year = {2019}, +} +``` diff --git a/espresso_logo.png b/espresso_logo.png new file mode 100644 index 0000000000000000000000000000000000000000..5c7ee7b978dae60f8e531e74f97f8ef6e618faba GIT binary patch literal 13344 zcmdUWbyOVR(jWr^3_e&OSa5e8oWXSEP@Ajcb7nLf(3#F5AFmA1Si?~ zz4!Lp?|kR%|GRTepYE<(b<3-}@2!c|f+^u)Q(&W@px~(}%WI>c09cR*D+mJ#p}-YM zL7o8K+Dfu0wO^?YkvCXyWg~AC6ddAz4ggAKHaQ9kDvpz`A;M5YUChSAjoZrB!`hD9 z&kc^mMnRGA6GL9Q*&(cG{oGvLy~X?_>HmcxhP?g<%tKH6FBF7}B)y@A7Ok9zmmRGT zw;(q!y%aVrEvzW?C=ci8{L*K+hgcpz=<jILW2Q^oXr0#iPJ61q|nf?>Ao9a6D%&x6h^rkv!1t)6`NtJ%xlVC=SWkU zK@Vr8{TI$^Rc0+4QwBo+YQF(uc1mZWol3i*29)>yRc>PZwTQjuwJ_jZPQ7YkMjv9g zFu>NdCRk?+2Fk{TIl^dss`VAp4T;pnq$%(gDi3jJ8kpxns6FyBQOQbcg+t{_Fg>ts z>#r3ylV8}VP#Gi`v%9(e!&h0MZ8yD=Z-dczh*Bk%mf3@jn^`dt|CB~_fHb}3L-d6W!{daG zLWI3HFnOIhr?ybrTJo+VG+N{a44xd7bjA9K7EFk3Cs3^Z!d>W`^+z|4hFqNmEU?`U zltT-aN2=@9@ZF0qlM2L~Bjif-bA^MuA=mVTnQgfY^ak!#VUv`RWQSuT%y%3SmRXTL z!brvsrF{?j66Ra6a+Fz4=)Ff67xzN6{sRR`D$@`MHT)l4)z*N0#S3_0hi^BuFN+PK zy*|J}+MlR{TzH5ePUutxw^(vX>}1VV;gMH~1hn@Z=*`}aY@G~m2T?MX1(Cs+4*-=- z8w^iCTKX;BiFsIRhQ=CO^^d-q{O-i^Kb`u^&N-@H*r;ttOykc(&gC#pCYD(dz&5KHtEs?c0Ax%wL~UK6?%?UaAT2!!n@ zO+n+z$x(=ygw&vud@#UxPB~f4A*n1vrk)VISzP{?He*iAEc1NVCz$!d@4uqPl60W5t~2s!5y8p~S@?Xs<8G@0$Ovb{G43gqHP)>sI3j;;1*I zLnUNl(B89uY*ZhkJBS3+HsrQ{!7>B?<6$eFO;Fi*HZa*+Tk-@J9LS>e*)&21gp^?? zUnjXUQKiTTKMt;reX3$DDXRDq`<{D;PLmcc1{l(C2G6EVF8Ad)RS{`l>WzeX8^X*dIYuA?OSVgp0+wvo_X2dQ5 zuqG`UQrj_o`LlRzBf^f6;tq!0USo&}=H@4&Or=h<;-)EEOL}S%Q%Ez=W)c_q!qwYd z__D}Pc9bp2>ngv_=$Qr-5Cb1yBb_cBI2fOyA#5}@_2=jH_XFkF-zfyKkNE7Ug`!I* zq7bu?zSDmXOW977cS5A0AvmZ|1!Uko4uum#Pu5QC;&e`JBC2(3TGZNG*(-WQFY@3iKnv1i>9$q?U5?mz6R%}3)m;4 zU*k3@SF^|&GV;m%11@N?7P?CH8ggyy&9!>Am~(#}7w*vk3=gBH?H?`PPI|GFF-Fle z8Z}h-j_|vwwIr#663s0&vNAAB3L;bMyX?y=|b6sNc`y zy*JcM!;G=z#6da%B~UwW%!%X+G9F%{`NkS`yz<5YBNNsRnMZ&PWL}XP&%WfgIc?ig zn%5}0d>T$UX8<9Pmcdd+iPGvKK={1cw2&W5uY4M<)ef=3}oCg*}}GEquH?aBmzPzhTK27gsGA7udfMa$Sm>j)R~WwoQRaQndv zh{e*uZ;Ig&553O~u^{VA>XAm(iY(3QMY9p3pWxJrKzg=p-o)7E2bqR&PIvU7_aCmj zXu+Ze5ehm^K)q>WL5ZY@CENEUE6WGvE-&)&nD2HPXg9=AKLkQs{1~l;JuzXQ8dZ^H zqmNju1thQP9cMVf{8RaZp{D6h6pRh^!%VjG7J}_W<0XiB-}e0#rJ+)~7@LeEI)D@; zuLxnXvWduK-E^u+q>yaYD{gjnsaNLaA$*enr-0;ehb@q_WN1`304oFoJd+0ltixI4 z$H2(JejJvcXxFb3H|z*_Mr+8!mCJuZl)ZxJp*-xv+*d;-O@i*PIJ&9sU#p@8$I37Y zvWO}tiCh=b9LL3EYE{V6US;Zx`y<}yYh+n^#05B#e9drjIWD&PUYmbPWnDlB%%%kg zOtR^7@nuorzCRjk9}Msc%X=>-Lp#^6Vl6C6{v1=ez|%bV4r}q8A)R)E54GSK%*W>7 zve)~~NXF|*F;LmSFkZ5uX2H>1m>~09KNd#I65tRswCOo9)zcbubBDXRANkxb+pdyk z%U;`atD@b}G+mU>`_-yn!F2?R(uc%8ge@0vxuZK@q|8hZf=W30A`~gi35!}BLsqF! 
z^CB{y|IRC;v5bci8v8kZjCy6|yG(SHC@`Z7eS1mYlr15#>v7|Wr@z7y+?uTv+(OX_ z98E{fCdQ$`yb)HL>aW>KUj*NDd7bbQaKrodf}q6ISarNZ~mu5VgQ zM$hbN2w_aqW{dJ81xlZz@~!gFwHQ=wfA2D=uPw;6e5O=#eG2%e7etLYn(aGb7{}yu zcu5Bp=|`;cng-$fbs-4bl20k8F0y41RXHFn4MU8kV6bd^oF6@(`b>@CUTOGrpAY=v zU}9PlVpHw{DUkV#G)v&Dwc|& zw^=WLaxF!iQ+akv6CY}C!KA2B%IE|dPyZG>W%+KQ==AyhWebO67h9?=vKPJceSZdZ z{DpRmq_0!;Oiv_fx85S+(|||*8@To~!|154c)fJao5LS|j$EDjGk+5S?J=yZQU<`D zfR*!YpI~&1-__w-#@&~}#;HLGJFMh-v}&R~vdB)<_%*)2B$ct$vcDmpt7(Q0ptrK~ z(E2%O?^SqxMHFq7d;t24IGwB5^mA&z9?U6U`e2rnQ@7hLwx&ekv&1 | tee $dir/logs/decode_$dataset${decode_affix:+_${decode_affix}}.log diff --git a/fairseq/criterions/cross_entropy_with_wer.py b/fairseq/criterions/cross_entropy_with_wer.py index fad3d104e..92d5b602f 100644 --- a/fairseq/criterions/cross_entropy_with_wer.py +++ b/fairseq/criterions/cross_entropy_with_wer.py @@ -188,6 +188,7 @@ def forward(self, model, sample, reduce=True): sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] logging_output = { 'loss': utils.item(loss.data) if reduce else loss.data, + 'nll_loss': utils.item(loss.data) if reduce else loss.data, 'ntokens': sample['ntokens'], 'nsentences': sample['target'].size(0), 'sample_size': sample_size, From 91d9fc7e29ef9a519291b6d92e1d16319e746c60 Mon Sep 17 00:00:00 2001 From: freewym Date: Fri, 20 Sep 2019 14:45:59 -0400 Subject: [PATCH 037/119] compansate for the removal of torch.rand() from distributed_init() recently introduced in fairseq, to make ASR results reproducible --- fairseq/tasks/language_modeling_for_asr.py | 4 ++++ fairseq/tasks/speech_recognition.py | 4 ++++ speech_train.py | 6 +++--- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/fairseq/tasks/language_modeling_for_asr.py b/fairseq/tasks/language_modeling_for_asr.py index d06f8e0e1..b0b96bd3a 100644 --- a/fairseq/tasks/language_modeling_for_asr.py +++ b/fairseq/tasks/language_modeling_for_asr.py @@ -56,6 +56,10 @@ def add_args(parser): def __init__(self, args, dictionary, output_dictionary=None, targets=None): super().__init__(args, dictionary, output_dictionary, targets=targets) torch.backends.cudnn.deterministic = True + # Compansate for the removal of :func:`torch.rand()` from + # :func:`fairseq.distributed_utils.distributed_init()` by fairseq, + # to make previous experiments reproducible. + torch.rand(1) @classmethod def load_dictionary(cls, filename): diff --git a/fairseq/tasks/speech_recognition.py b/fairseq/tasks/speech_recognition.py index 10700daef..971113f72 100644 --- a/fairseq/tasks/speech_recognition.py +++ b/fairseq/tasks/speech_recognition.py @@ -113,6 +113,10 @@ def __init__(self, args, dict, word_dict=None): self.word_dict = word_dict self.feat_in_channels = args.feat_in_channels torch.backends.cudnn.deterministic = True + # Compansate for the removel of :func:`torch.rand()` from + # :func:`fairseq.distributed_utils.distributed_init()` by fairseq, + # to make previous experiments reproducible. 
+ torch.rand(1) @classmethod def setup_task(cls, args, **kwargs): diff --git a/speech_train.py b/speech_train.py index c240e8da4..ee273a139 100755 --- a/speech_train.py +++ b/speech_train.py @@ -104,9 +104,9 @@ def main(args, init_distributed=False): if epoch_itr.epoch % args.save_interval == 0: checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) - if len(args.train_feat_files) > 1: - # sharded data: get train iterator for next epoch - epoch_itr = trainer.get_train_iterator(epoch_itr.epoch) + reload_dataset = len(args.train_feat_files) > 1 + # sharded data: get train iterator for next epoch + epoch_itr = trainer.get_train_iterator(epoch_itr.epoch, load_dataset=reload_dataset) train_meter.stop() print('| done training in {:.1f} seconds'.format(train_meter.sum)) From 6a8f4d30c858927054bf62998df39f411da9f975 Mon Sep 17 00:00:00 2001 From: freewym Date: Wed, 25 Sep 2019 16:10:23 -0400 Subject: [PATCH 038/119] add backoff smoothing for unigram label smoothing --- fairseq/criterions/label_smoothed_cross_entropy_with_wer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py b/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py index 1ad541ead..ca3369458 100644 --- a/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py +++ b/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py @@ -63,6 +63,7 @@ def __init__(self, args, task): self.unigram_tensor = torch.cuda.FloatTensor(dict.count).unsqueeze(-1) \ if torch.cuda.is_available() and not args.cpu \ else torch.FloatTensor(dict.count).unsqueeze(-1) + self.unigram_tensor += args.unigram_pseudo_count # for further backoff self.unigram_tensor.div_(self.unigram_tensor.sum()) @staticmethod @@ -77,6 +78,9 @@ def add_args(parser): parser.add_argument('--smoothing-type', type=str, default='uniform', choices=['uniform', 'unigram', 'temporal'], help='label smoothing type. Default: uniform') + parser.add_argument('--unigram-pseudo-count', type=float, default=1.0, + metavar='C', help='pseudo count for unigram label ' + 'smoothing. 
Only relevant if --smoothing-type=unigram') parser.add_argument('--scheduled-sampling-probs', type=lambda p: eval_str_list(p), metavar='P_1,P_2,...,P_N', default=1.0, help='scheduled sampling probabilities of sampling the truth ' From 32825b5357ed9eb041a48a29f35fe43f43385f9e Mon Sep 17 00:00:00 2001 From: freewym Date: Sat, 28 Sep 2019 19:00:02 -0400 Subject: [PATCH 039/119] a better LM for Librispeech yielding better WERs; code adaptation/changes according to the commits on Sep 27, 2019 --- examples/asr_librispeech/run.sh | 6 +++--- fairseq/data/token_dictionary.py | 6 ++++-- speech_train.py | 5 +++++ 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/examples/asr_librispeech/run.sh b/examples/asr_librispeech/run.sh index 6bc4487a7..0231bb63c 100755 --- a/examples/asr_librispeech/run.sh +++ b/examples/asr_librispeech/run.sh @@ -170,7 +170,7 @@ if [ ${stage} -le 5 ]; then CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../train.py $lmdatadir --seed 1 \ --task language_modeling_for_asr --dict $lmdict \ --log-interval 8000 --log-format simple \ - --num-workers 0 --max-tokens 30720 --max-sentences 1024 --curriculum 1 \ + --num-workers 0 --max-tokens 32000 --max-sentences 1024 --curriculum 1 \ --valid-subset $valid_subset --max-sentences-valid 1536 \ --distributed-world-size $ngpus --distributed-port 100 \ --max-epoch 30 --optimizer adam --lr 0.001 --clip-norm 1.0 \ @@ -230,7 +230,7 @@ if [ ${stage} -le 8 ]; then decode_affix= if $lm_shallow_fusion; then path="$path:$lmdir/$lm_checkpoint" - opts="$opts --lm-weight 0.42 --coverage-weight 0.0 --eos-factor 1.5" + opts="$opts --lm-weight 0.47 --coverage-weight 0.0 --eos-factor 1.5" decode_affix=shallow_fusion fi for dataset in $test_set; do @@ -241,7 +241,7 @@ if [ ${stage} -le 8 ]; then --test-feat-files $feat --test-text-files $text \ --dict $dict --remove-bpe sentencepiece \ --max-source-positions 9999 --max-target-positions 999 \ - --path $path --beam 40 --max-len-a 0.08 --max-len-b 0 --lenpen 1.0 \ + --path $path --beam 60 --max-len-a 0.08 --max-len-b 0 --lenpen 1.0 \ --results-path $dir/decode_$dataset${decode_affix:+_${decode_affix}} $opts \ 2>&1 | tee $dir/logs/decode_$dataset${decode_affix:+_${decode_affix}}.log diff --git a/fairseq/data/token_dictionary.py b/fairseq/data/token_dictionary.py index dff3177d5..978abf92c 100644 --- a/fairseq/data/token_dictionary.py +++ b/fairseq/data/token_dictionary.py @@ -51,8 +51,10 @@ def token_string(i): else: return self[i] - sent = ' '.join(token_string(i) for i in tensor if i != self.eos() and \ - i != self.pad()) + if hasattr(self, 'bos_index'): + sent = ' '.join(token_string(i) for i in tensor if (i != self.eos()) and (i != self.bos()) and (i != self.pad())) + else: + sent = ' '.join(token_string(i) for i in tensor if i != self.eos() and i != self.pad()) return data_utils.process_bpe_symbol(sent, bpe_symbol) def bos(self): diff --git a/speech_train.py b/speech_train.py index ee273a139..23c211b91 100755 --- a/speech_train.py +++ b/speech_train.py @@ -212,6 +212,11 @@ def get_training_stats(trainer): def validate(args, trainer, task, epoch_itr, subsets): """Evaluate the model on the validation set(s) and return the losses.""" + + if args.fixed_validation_seed is not None: + # set fixed seed for every validation + utils.set_torch_seed(args.fixed_validation_seed) + valid_losses = [] for subset in subsets: # Initialize data iterator From c63caa801a7b1c7fcf727c117dfb2f82e6795376 Mon Sep 17 00:00:00 2001 From: freewym Date: Mon, 30 Sep 2019 16:07:10 -0400 Subject: [PATCH 040/119] code 
adaptation/changes according to the commits on Sep 30, 2019 --- fairseq/data/speech_dataset.py | 3 ++- fairseq/tasks/speech_recognition.py | 4 ++-- speech_recognize.py | 13 ++++++++++--- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/fairseq/data/speech_dataset.py b/fairseq/data/speech_dataset.py index 9c7a02e87..d5533d23a 100644 --- a/fairseq/data/speech_dataset.py +++ b/fairseq/data/speech_dataset.py @@ -145,12 +145,13 @@ def _match_src_tgt(self): def __getitem__(self, index): tgt_item = self.tgt[index] if self.tgt is not None else None src_item = self.src[index] - return { + example = { 'id': index, 'utt_id': self.src.utt_ids[index], 'source': src_item, 'target': tgt_item, } + return example def __len__(self): return len(self.src) diff --git a/fairseq/tasks/speech_recognition.py b/fairseq/tasks/speech_recognition.py index 971113f72..4d5df3388 100644 --- a/fairseq/tasks/speech_recognition.py +++ b/fairseq/tasks/speech_recognition.py @@ -174,8 +174,8 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs): assert len(feat_files) > 0 and len(feat_files) == len(text_files) file_pairs = zip(feat_files, text_files) for feat, text in file_pairs: - assert ScpCachedDataset.exists(feat) - assert text is None or TokenTextDataset.exists(text) + assert ScpCachedDataset.exists(feat), feat + ' does not exists' + assert text is None or TokenTextDataset.exists(text), text + ' does not exists' src_datasets.append(ScpCachedDataset(feat, ordered_prefetch=True)) print('| {} {} examples'.format(feat, len(src_datasets[-1]))) if text is not None: diff --git a/speech_recognize.py b/speech_recognize.py index 93f378f58..e40248edf 100755 --- a/speech_recognize.py +++ b/speech_recognize.py @@ -18,7 +18,7 @@ from fairseq.models.external_language_model import MultiLevelLanguageModel from fairseq.models.tensorized_lookahead_language_model import TensorizedLookaheadLanguageModel from fairseq.utils import import_user_module -from speech_tools.utils import plot_attention +from speech_tools.utils import plot_attention, sequence_mask def main(args): @@ -125,6 +125,13 @@ def main(args): num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos) gen_timer.stop(num_generated_tokens) + # obtain nonpad mask of encoder output to plot attentions + if args.print_alignment: + net_input = sample['net_input'] + src_tokens = net_input['src_tokens'] + output_lengths = models[0].encoder.output_lengths(net_input['src_lengths']) + nonpad_idxs = sequence_mask(output_lengths, models[0].encoder.output_lengths(src_tokens.size(1))) + for i, sample_id in enumerate(sample['id'].tolist()): has_target = sample['target'] is not None utt_id = sample['utt_id'][i] @@ -149,8 +156,8 @@ def main(args): # Score and obtain attention only the top hypothesis if j == 0: # src_len x tgt_len - attention = hypo['attention'].float().cpu() \ - if hypo['attention'] is not None else None + attention = hypo['attention'][nonpad_idxs[i]].float().cpu() \ + if args.print_alignment and hypo['attention'] is not None else None if args.print_alignment and attention is not None: save_dir = os.path.join(args.results_path, 'attn_plots') os.makedirs(save_dir, exist_ok=True) From 0e610d56e8280d55e3144200f90c44e1dda5f1e7 Mon Sep 17 00:00:00 2001 From: freewym Date: Sat, 12 Oct 2019 21:41:25 -0400 Subject: [PATCH 041/119] set --distributed-port=-1 if ngpus=1; code adaptation/changes according to the commits on Oct 11, 2019 --- examples/asr_librispeech/path.sh | 1 - examples/asr_librispeech/run.sh | 6 +++--- examples/asr_swbd/path.sh | 1 - 
examples/asr_swbd/run.sh | 6 +++--- examples/asr_wsj/run.sh | 8 ++++---- speech_train.py | 11 +++++++++++ 6 files changed, 21 insertions(+), 12 deletions(-) diff --git a/examples/asr_librispeech/path.sh b/examples/asr_librispeech/path.sh index d0ebe2157..19308a610 100644 --- a/examples/asr_librispeech/path.sh +++ b/examples/asr_librispeech/path.sh @@ -11,6 +11,5 @@ export LC_ALL=C export PATH=~/anaconda3/bin:$PATH export PATH=$MAIN_ROOT:$MAIN_ROOT/speech_tools:$PATH -export PATH=$MAIN_ROOT/speech_tools/sentencepiece/build/src:$PATH export PYTHONPATH=$MAIN_ROOT:$MAIN_ROOT/speech_tools:$PYTHONPATH export PYTHONUNBUFFERED=1 diff --git a/examples/asr_librispeech/run.sh b/examples/asr_librispeech/run.sh index 0231bb63c..6cdd3712a 100755 --- a/examples/asr_librispeech/run.sh +++ b/examples/asr_librispeech/run.sh @@ -172,7 +172,7 @@ if [ ${stage} -le 5 ]; then --log-interval 8000 --log-format simple \ --num-workers 0 --max-tokens 32000 --max-sentences 1024 --curriculum 1 \ --valid-subset $valid_subset --max-sentences-valid 1536 \ - --distributed-world-size $ngpus --distributed-port 100 \ + --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ --max-epoch 30 --optimizer adam --lr 0.001 --clip-norm 1.0 \ --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \ --save-dir $lmdir --restore-file checkpoint_last.pt --save-interval-updates 8000 \ @@ -208,8 +208,8 @@ if [ ${stage} -le 7 ]; then CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --seed 1 \ --log-interval 4000 --log-format simple --print-training-sample-interval 2000 \ --num-workers 0 --max-tokens 26000 --max-sentences 24 --curriculum 1 \ - --valid-subset $valid_subset --max-sentences-valid 48 \ - --distributed-world-size $ngpus --distributed-port 100 --ddp-backend no_c10d \ + --valid-subset $valid_subset --max-sentences-valid 48 --ddp-backend no_c10d \ + --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ --max-epoch 30 --optimizer adam --lr 0.001 --weight-decay 0.0 --clip-norm 2.0 \ --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --min-lr 1e-5 --start-reduce-lr-epoch 10 \ --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates 3000 \ diff --git a/examples/asr_swbd/path.sh b/examples/asr_swbd/path.sh index d0ebe2157..19308a610 100644 --- a/examples/asr_swbd/path.sh +++ b/examples/asr_swbd/path.sh @@ -11,6 +11,5 @@ export LC_ALL=C export PATH=~/anaconda3/bin:$PATH export PATH=$MAIN_ROOT:$MAIN_ROOT/speech_tools:$PATH -export PATH=$MAIN_ROOT/speech_tools/sentencepiece/build/src:$PATH export PYTHONPATH=$MAIN_ROOT:$MAIN_ROOT/speech_tools:$PYTHONPATH export PYTHONUNBUFFERED=1 diff --git a/examples/asr_swbd/run.sh b/examples/asr_swbd/run.sh index 971cc0efc..291b01071 100755 --- a/examples/asr_swbd/run.sh +++ b/examples/asr_swbd/run.sh @@ -211,7 +211,7 @@ if [ $stage -le 4 ]; then --log-interval 500 --log-format simple \ --num-workers 0 --max-tokens 25600 --max-sentences 1024 \ --valid-subset $valid_subset --max-sentences-valid 1536 \ - --distributed-world-size $ngpus --distributed-rank 0 --distributed-port 100 \ + --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ --max-epoch 25 --optimizer adam --lr 0.001 --clip-norm 1.0 \ --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \ --save-dir $lmdir --restore-file checkpoint_last.pt --save-interval-updates 500 \ @@ -249,8 +249,8 @@ if [ $stage -le 6 ]; then CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --seed 1 \ 
--log-interval 1500 --log-format simple --print-training-sample-interval 2000 \ --num-workers 0 --max-tokens 26000 --max-sentences 48 --curriculum 2 \ - --valid-subset $valid_subset --max-sentences-valid 64 \ - --distributed-world-size $ngpus --distributed-rank 0 --distributed-port 100 --ddp-backend no_c10d \ + --valid-subset $valid_subset --max-sentences-valid 64 --ddp-backend no_c10d \ + --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ --max-epoch 35 --optimizer adam --lr 0.001 --weight-decay 0.0 --clip-norm 2.0 \ --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --min-lr 1e-5 --start-reduce-lr-epoch 10 \ --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates 1500 \ diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index 897da4d77..f7f553b6e 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -197,7 +197,7 @@ if [ ${stage} -le 4 ] && ! $use_wordlm; then --log-interval 2000 --log-format simple \ --num-workers 0 --max-tokens 25600 --max-sentences 128 \ --valid-subset $valid_subset --max-sentences-valid 256 \ - --distributed-world-size $ngpus --distributed-port 100 \ + --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ --max-epoch 25 --optimizer adam --lr 0.001 --weight-decay 5e-06 \ --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \ --save-dir $lmdir --restore-file checkpoint_last.pt --save-interval-updates 2000 \ @@ -227,7 +227,7 @@ if [ ${stage} -le 6 ] && $use_wordlm; then --log-interval 2000 --log-format simple \ --num-workers 0 --max-tokens 6400 --max-sentences 256 \ --valid-subset $valid_subset --max-sentences-valid 512 \ - --distributed-world-size $ngpus --distributed-port 100 \ + --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ --max-epoch 25 --optimizer adam --lr 0.001 --weight-decay 0.0 \ --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \ --save-dir $wordlmdir --restore-file checkpoint_last.pt --save-interval-updates 2000 \ @@ -267,8 +267,8 @@ if [ ${stage} -le 8 ]; then CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --seed 1 \ --log-interval 400 --log-format simple --print-training-sample-interval 1000 \ --num-workers 0 --max-tokens 24000 --max-sentences 32 --curriculum 2 \ - --valid-subset $valid_subset --max-sentences-valid 64 \ - --distributed-world-size $ngpus --distributed-port 100 --ddp-backend no_c10d \ + --valid-subset $valid_subset --max-sentences-valid 64 --ddp-backend no_c10d \ + --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ --max-epoch 35 --optimizer adam --lr 0.001 --weight-decay 0.0 \ --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --min-lr 1e-5 --start-reduce-lr-epoch 11 \ --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates 400 \ diff --git a/speech_train.py b/speech_train.py index 23c211b91..5ebccfcd3 100755 --- a/speech_train.py +++ b/speech_train.py @@ -21,10 +21,21 @@ from fairseq.meters import AverageMeter, StopwatchMeter from fairseq.utils import import_user_module +fb_pathmgr_registerd = False + def main(args, init_distributed=False): utils.import_user_module(args) + try: + from fairseq.fb_pathmgr import fb_pathmgr + global fb_pathmgr_registerd + if not fb_pathmgr_registerd: + fb_pathmgr.register() + fb_pathmgr_registerd = True + except (ModuleNotFoundError, ImportError): + pass + assert args.max_tokens is not None or 
args.max_sentences is not None, \ 'Must specify batch size either with --max-tokens or --max-sentences' From 67fdff19494ae4f7566798a8211bf90ea483c0f0 Mon Sep 17 00:00:00 2001 From: freewym Date: Fri, 18 Oct 2019 22:04:51 -0400 Subject: [PATCH 042/119] change warmup scheduling for ReduceLROnPlateauV2; code adaptation/changes according to the commits on Oct 18, 2019 --- examples/asr_librispeech/run.sh | 2 +- examples/asr_swbd/run.sh | 2 +- examples/asr_wsj/run.sh | 2 +- .../lr_scheduler/reduce_lr_on_plateau_v2.py | 18 ++++++++++-------- speech_train.py | 10 +++------- 5 files changed, 16 insertions(+), 18 deletions(-) diff --git a/examples/asr_librispeech/run.sh b/examples/asr_librispeech/run.sh index 6cdd3712a..b78e80fa7 100755 --- a/examples/asr_librispeech/run.sh +++ b/examples/asr_librispeech/run.sh @@ -211,7 +211,7 @@ if [ ${stage} -le 7 ]; then --valid-subset $valid_subset --max-sentences-valid 48 --ddp-backend no_c10d \ --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ --max-epoch 30 --optimizer adam --lr 0.001 --weight-decay 0.0 --clip-norm 2.0 \ - --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --min-lr 1e-5 --start-reduce-lr-epoch 10 \ + --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --start-reduce-lr-epoch 10 \ --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates 3000 \ --keep-interval-updates 3 --keep-last-epochs 5 --validate-interval 1 --best-checkpoint-metric wer \ --arch speech_conv_lstm_librispeech --criterion label_smoothed_cross_entropy_with_wer \ diff --git a/examples/asr_swbd/run.sh b/examples/asr_swbd/run.sh index 291b01071..c4a8a296d 100755 --- a/examples/asr_swbd/run.sh +++ b/examples/asr_swbd/run.sh @@ -252,7 +252,7 @@ if [ $stage -le 6 ]; then --valid-subset $valid_subset --max-sentences-valid 64 --ddp-backend no_c10d \ --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ --max-epoch 35 --optimizer adam --lr 0.001 --weight-decay 0.0 --clip-norm 2.0 \ - --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --min-lr 1e-5 --start-reduce-lr-epoch 10 \ + --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --start-reduce-lr-epoch 10 \ --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates 1500 \ --keep-interval-updates 3 --keep-last-epochs 5 --validate-interval 1 --best-checkpoint-metric wer \ --arch speech_conv_lstm_swbd --criterion label_smoothed_cross_entropy_with_wer \ diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index f7f553b6e..4c6ad2cea 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -270,7 +270,7 @@ if [ ${stage} -le 8 ]; then --valid-subset $valid_subset --max-sentences-valid 64 --ddp-backend no_c10d \ --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ --max-epoch 35 --optimizer adam --lr 0.001 --weight-decay 0.0 \ - --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --min-lr 1e-5 --start-reduce-lr-epoch 11 \ + --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --start-reduce-lr-epoch 11 \ --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates 400 \ --keep-interval-updates 5 --keep-last-epochs 5 --validate-interval 1 --best-checkpoint-metric wer \ --arch speech_conv_lstm_wsj --criterion label_smoothed_cross_entropy_with_wer \ diff --git a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau_v2.py b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau_v2.py index 
6348a313f..449fcb9ea 100644 --- a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau_v2.py +++ b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau_v2.py @@ -21,14 +21,13 @@ class ReduceLROnPlateauV2(ReduceLROnPlateau): def __init__(self, args, optimizer): super().__init__(args, optimizer) - if args.warmup_updates > 0: - self.warmup_factor = 1. / args.warmup_updates - else: - self.warmup_factor = 1. + self.init_lr = args.init_lr_scale * args.lr[0] if args.warmup_updates > 0 else args.lr[0] + self.warmup_rate = (args.lr[0] - self.init_lr) / args.warmup_updates \ + if args.warmup_updates > 0 else 0. self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( self.optimizer.optimizer, patience=0, factor=args.lr_shrink, - threshold=args.lr_threshold, min_lr=args.min_lr) + threshold=args.lr_threshold, min_lr=args.final_lr_scale * args.lr[0]) @staticmethod def add_args(parser): @@ -37,6 +36,10 @@ def add_args(parser): # fmt: off parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', help='warmup the learning rate linearly for the first N updates') + parser.add_argument('--init-lr-scale', default=0.01, type=float, metavar='N', + help='initial learning rate scale during warmup phase; default is 0.01') + parser.add_argument('--final-lr-scale', default=0.01, type=float, metavar='N', + help='final learning rate scale; default to 0.01') parser.add_argument('--start-reduce-lr-epoch', default=0, type=int, metavar='N', help='start to reduce lr from the specified epoch') # fmt: on @@ -44,13 +47,12 @@ def add_args(parser): def step(self, epoch, val_loss=None): if epoch < self.args.start_reduce_lr_epoch: self.lr_scheduler.last_epoch = epoch - self.optimizer.set_lr(self.warmup_factor * self.args.lr[0]) + self.optimizer.set_lr(self.args.lr[0]) return self.optimizer.get_lr() return super().step(epoch, val_loss) def step_update(self, num_updates): """Update the learning rate after each update.""" if self.args.warmup_updates > 0 and num_updates <= self.args.warmup_updates: - self.warmup_factor = num_updates / float(self.args.warmup_updates) - self.optimizer.set_lr(self.warmup_factor * self.args.lr[0]) + self.optimizer.set_lr(self.init_lr + self.warmup_rate * num_updates) return self.optimizer.get_lr() diff --git a/speech_train.py b/speech_train.py index 5ebccfcd3..c0d07aedf 100755 --- a/speech_train.py +++ b/speech_train.py @@ -93,13 +93,9 @@ def main(args, init_distributed=False): train_meter.start() valid_subsets = args.valid_subset.split(',') while ( - (lr >= args.min_lr or trainer.get_num_updates() <= getattr(args, 'warmup_updates', 0)) - and ( - epoch_itr.epoch < max_epoch or ( - epoch_itr.epoch == max_epoch - and epoch_itr._next_epoch_itr is not None - ) - ) + lr > args.min_lr + and (epoch_itr.epoch < max_epoch or (epoch_itr.epoch == max_epoch + and epoch_itr._next_epoch_itr is not None)) and trainer.get_num_updates() < max_update ): # train for one epoch From aab85e5603ddeb8510d357b6d01e41a61528b0ff Mon Sep 17 00:00:00 2001 From: freewym Date: Tue, 22 Oct 2019 23:32:38 -0400 Subject: [PATCH 043/119] remove warmup code in ReduceLROnPlateauV2 as fairseq just added it; code adaptation/changes according to the commits on Oct 23, 2019 --- .../lr_scheduler/reduce_lr_on_plateau_v2.py | 18 ++---------------- speech_train.py | 3 ++- 2 files changed, 4 insertions(+), 17 deletions(-) diff --git a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau_v2.py b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau_v2.py index 449fcb9ea..54b8ab64b 100644 --- a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau_v2.py
+++ b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau_v2.py @@ -14,17 +14,13 @@ class ReduceLROnPlateauV2(ReduceLROnPlateau): """Decay the LR by a factor every time the validation loss plateaus, starting from the epoch specified as args.start_reduce_lr_epoch. - We also support a warmup phase where we linearly increase the learning rate - from 0 until the configured learning rate (``--lr``). + We also support specifying a final lr which will be kept until the max number + of epochs is reached. """ def __init__(self, args, optimizer): super().__init__(args, optimizer) - self.init_lr = args.init_lr_scale * args.lr[0] if args.warmup_updates > 0 else args.lr[0] - self.warmup_rate = (args.lr[0] - self.init_lr) / args.warmup_updates \ - if args.warmup_updates > 0 else 0. - self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( self.optimizer.optimizer, patience=0, factor=args.lr_shrink, threshold=args.lr_threshold, min_lr=args.final_lr_scale * args.lr[0]) @@ -34,10 +30,6 @@ def add_args(parser): """Add arguments to the parser for this LR scheduler.""" ReduceLROnPlateau.add_args(parser) # fmt: off - parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', - help='warmup the learning rate linearly for the first N updates') - parser.add_argument('--init-lr-scale', default=0.01, type=float, metavar='N', - help='initial learning rate scale during warmup phase; default is 0.01') parser.add_argument('--final-lr-scale', default=0.01, type=float, metavar='N', help='final learning rate scale; default to 0.01') parser.add_argument('--start-reduce-lr-epoch', default=0, type=int, metavar='N', @@ -50,9 +42,3 @@ def step(self, epoch, val_loss=None): self.optimizer.set_lr(self.args.lr[0]) return self.optimizer.get_lr() return super().step(epoch, val_loss) - - def step_update(self, num_updates): - """Update the learning rate after each update.""" - if self.args.warmup_updates > 0 and num_updates <= self.args.warmup_updates: - self.optimizer.set_lr(self.init_lr + self.warmup_rate * num_updates) - return self.optimizer.get_lr() diff --git a/speech_train.py b/speech_train.py index c0d07aedf..73d103665 100755 --- a/speech_train.py +++ b/speech_train.py @@ -160,9 +160,10 @@ def train(args, trainer, task, epoch_itr): stats[k] = extra_meters[k].avg progress.log(stats, tag='train', step=stats['num_updates']) - # ignore the first mini-batch in words-per-second calculation + # ignore the first mini-batch in words-per-second and updates-per-second calculation if i == 0: trainer.get_meter('wps').reset() + trainer.get_meter('ups').reset() num_updates = trainer.get_num_updates() if ( From 3018d200188f876a4f083928725c250c6223480b Mon Sep 17 00:00:00 2001 From: freewym Date: Wed, 30 Oct 2019 00:59:18 -0400 Subject: [PATCH 044/119] add gpu.conf for SGE qsub --- examples/asr_librispeech/cmd.sh | 8 ++++---- examples/asr_librispeech/conf/gpu.conf | 10 ++++++++++ examples/asr_swbd/conf/gpu.conf | 10 ++++++++++ examples/asr_wsj/conf/gpu.conf | 10 ++++++++++ 4 files changed, 34 insertions(+), 4 deletions(-) create mode 100644 examples/asr_librispeech/conf/gpu.conf create mode 100644 examples/asr_swbd/conf/gpu.conf create mode 100644 examples/asr_wsj/conf/gpu.conf diff --git a/examples/asr_librispeech/cmd.sh b/examples/asr_librispeech/cmd.sh index b14280b96..9e73f25da 100644 --- a/examples/asr_librispeech/cmd.sh +++ b/examples/asr_librispeech/cmd.sh @@ -10,11 +10,11 @@ # conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information, # or search for the string 'default_config' in utils/queue.pl or 
utils/slurm.pl. -#export train_cmd="run.pl --mem 4G" -#export cuda_cmd="run.pl --mem 4G --gpu 1" +#export train_cmd="run.pl --mem 10G" +#export cuda_cmd="run.pl --mem 10G --gpu 1" #export decode_cmd="run.pl --mem 4G" # JHU setup -export train_cmd="queue.pl --mem 4G" -export cuda_cmd="queue.pl --mem 4G --gpu 1 --config conf/gpu.conf" +export train_cmd="queue.pl --mem 10G" +export cuda_cmd="queue.pl --mem 10G --gpu 1 --config conf/gpu.conf" export decode_cmd="queue.pl --mem 4G" diff --git a/examples/asr_librispeech/conf/gpu.conf b/examples/asr_librispeech/conf/gpu.conf new file mode 100644 index 000000000..5cc94adf2 --- /dev/null +++ b/examples/asr_librispeech/conf/gpu.conf @@ -0,0 +1,10 @@ +# Default configuration +command qsub -v PATH -cwd -S /bin/bash -j y -l arch=*64* +option mem=* -l mem_free=$0,ram_free=$0 +option mem=0 # Do not add anything to qsub_opts +option num_threads=* -pe smp $0 +option num_threads=1 # Do not add anything to qsub_opts +option max_jobs_run=* -tc $0 +default gpu=0 +option gpu=0 +option gpu=* -l 'hostname=c*,gpu=$0' -q g.q diff --git a/examples/asr_swbd/conf/gpu.conf b/examples/asr_swbd/conf/gpu.conf new file mode 100644 index 000000000..5cc94adf2 --- /dev/null +++ b/examples/asr_swbd/conf/gpu.conf @@ -0,0 +1,10 @@ +# Default configuration +command qsub -v PATH -cwd -S /bin/bash -j y -l arch=*64* +option mem=* -l mem_free=$0,ram_free=$0 +option mem=0 # Do not add anything to qsub_opts +option num_threads=* -pe smp $0 +option num_threads=1 # Do not add anything to qsub_opts +option max_jobs_run=* -tc $0 +default gpu=0 +option gpu=0 +option gpu=* -l 'hostname=c*,gpu=$0' -q g.q diff --git a/examples/asr_wsj/conf/gpu.conf b/examples/asr_wsj/conf/gpu.conf new file mode 100644 index 000000000..5cc94adf2 --- /dev/null +++ b/examples/asr_wsj/conf/gpu.conf @@ -0,0 +1,10 @@ +# Default configuration +command qsub -v PATH -cwd -S /bin/bash -j y -l arch=*64* +option mem=* -l mem_free=$0,ram_free=$0 +option mem=0 # Do not add anything to qsub_opts +option num_threads=* -pe smp $0 +option num_threads=1 # Do not add anything to qsub_opts +option max_jobs_run=* -tc $0 +default gpu=0 +option gpu=0 +option gpu=* -l 'hostname=c*,gpu=$0' -q g.q From a1b76df8a43a03395443509c68629144c0123da6 Mon Sep 17 00:00:00 2001 From: Shujian2015 Date: Fri, 1 Nov 2019 20:23:13 -0400 Subject: [PATCH 045/119] Fixed error when using fp16 Fixed error when using fp16. 
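A minimal sketch of the fp16-safe masked-softmax pattern the fix below relies on (standalone illustration with assumed tensor shapes and names, not code from the patch):

import torch.nn.functional as F

def masked_softmax(attn_scores, key_padding_mask=None):
    # attn_scores: (src_len, bsz), possibly half precision; key_padding_mask: (src_len, bsz) bool
    if key_padding_mask is not None:
        # cast to float before filling with -inf so the masking is numerically safe under fp16,
        # then cast back to the original dtype
        attn_scores = attn_scores.float().masked_fill_(
            key_padding_mask, float('-inf'),
        ).type_as(attn_scores)
    # normalize over the source-length dimension
    return F.softmax(attn_scores, dim=0)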
Followed example from https://github.com/pytorch/fairseq/blob/master/fairseq/models/lstm.py#L299 --- fairseq/modules/speech_attention.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/fairseq/modules/speech_attention.py b/fairseq/modules/speech_attention.py index 114fa1aca..002964eb1 100644 --- a/fairseq/modules/speech_attention.py +++ b/fairseq/modules/speech_attention.py @@ -7,6 +7,7 @@ import torch from torch import nn from torch.nn import Parameter +import torch.nn.functional as F from fairseq import utils @@ -68,15 +69,14 @@ def forward(self, query, value, key_padding_mask=None, state=None): attn_scores = (normed_v * torch.tanh(projected_query + key + \ self.b)).sum(dim=2) # len x bsz else: - attn_scores = v * torch.tanh(projected_query + key).sum(dim=2) + attn_scores = self.v * torch.tanh(projected_query + key).sum(dim=2) if key_padding_mask is not None: attn_scores = attn_scores.float().masked_fill_( key_padding_mask, float('-inf'), ).type_as(attn_scores) # FP16 support: cast to float and back - attn_scores = utils.softmax(attn_scores, dim=0, - onnx_trace=self.onnx_trace).type_as(attn_scores) # len x bsz + attn_scores = F.softmax(attn_scores, dim=0) # srclen x bsz # sum weighted value. context: bsz x value_dim context = (attn_scores.unsqueeze(2) * value).sum(dim=0) @@ -115,8 +115,7 @@ def forward(self, query, value, key_padding_mask=None, state=None): key_padding_mask, float('-inf'), ).type_as(attn_scores) # FP16 support: cast to float and back - attn_scores = utils.softmax(attn_scores, dim=0, - onnx_trace=self.onnx_trace).type_as(attn_scores) # len x bsz + attn_scores = F.softmax(attn_scores, dim=0) # srclen x bsz # sum weighted value. context: bsz x value_dim context = (attn_scores.unsqueeze(2) * value).sum(dim=0) From 75581aea290ebe72c9c08b60df88f06aaddae834 Mon Sep 17 00:00:00 2001 From: freewym Date: Fri, 8 Nov 2019 20:38:36 -0500 Subject: [PATCH 046/119] code adaptation/changes according to the commits on Nov 8, 2019 --- speech_train.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/speech_train.py b/speech_train.py index 73d103665..54d158bb8 100755 --- a/speech_train.py +++ b/speech_train.py @@ -21,21 +21,10 @@ from fairseq.meters import AverageMeter, StopwatchMeter from fairseq.utils import import_user_module -fb_pathmgr_registerd = False - def main(args, init_distributed=False): utils.import_user_module(args) - try: - from fairseq.fb_pathmgr import fb_pathmgr - global fb_pathmgr_registerd - if not fb_pathmgr_registerd: - fb_pathmgr.register() - fb_pathmgr_registerd = True - except (ModuleNotFoundError, ImportError): - pass - assert args.max_tokens is not None or args.max_sentences is not None, \ 'Must specify batch size either with --max-tokens or --max-sentences' @@ -103,6 +92,8 @@ def main(args, init_distributed=False): if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0: valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets) + else: + valid_losses = [None] # only use first validation wer to update the learning rate lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0]) From 0e91c3b431bafd8a4f49bebb064411da2f53263e Mon Sep 17 00:00:00 2001 From: freewym Date: Wed, 27 Nov 2019 00:10:41 -0500 Subject: [PATCH 047/119] code adaptation/changes according to the commits on Nov 26, 2019 --- fairseq/data/token_dictionary.py | 7 +++---- fairseq/models/speech_lstm.py | 19 +++++++++++-------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git 
a/fairseq/data/token_dictionary.py b/fairseq/data/token_dictionary.py index 978abf92c..75d327bbd 100644 --- a/fairseq/data/token_dictionary.py +++ b/fairseq/data/token_dictionary.py @@ -66,7 +66,7 @@ def space(self): return self.space_index @classmethod - def load(cls, f, f_non_lang_syms=None, ignore_utf_errors=False): + def load(cls, f, f_non_lang_syms=None): """Loads the dictionary from a text file with the format: ``` @@ -81,14 +81,13 @@ def load(cls, f, f_non_lang_syms=None, ignore_utf_errors=False): Loads non_lang_syms from another text file, if it exists, with one symbol per line """ - d = super().load(f, ignore_utf_errors) + d = super().load(f) d.space_index = d.indices.get(d.space_word, -1) if f_non_lang_syms is not None: assert isinstance(f_non_lang_syms, str) try: - with open(f_non_lang_syms, 'r', encoding='utf-8', - errors='ignore' if ignore_utf_errors else None) as fd: + with open(f_non_lang_syms, 'r', encoding='utf-8') as fd: non_lang_syms = [x.rstrip() for x in fd.readlines()] except FileNotFoundError as fnfe: raise fnfe diff --git a/fairseq/models/speech_lstm.py b/fairseq/models/speech_lstm.py index 3856838ec..d682d46d0 100644 --- a/fairseq/models/speech_lstm.py +++ b/fairseq/models/speech_lstm.py @@ -559,11 +559,13 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, - the decoder's output of shape `(batch, tgt_len, vocab)` - attention weights of shape `(batch, tgt_len, src_len)` """ - x, extra = self.extract_features(prev_output_tokens, encoder_out, incremental_state) - x = self.output_layer(x) - return x, extra + x, attn_scores = self.extract_features( + prev_output_tokens, encoder_out, incremental_state + ) + return self.output_layer(x), attn_scores - def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused): + def extract_features( + self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused): """ Similar to *forward* but only return features. 
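For context, the decoder refactoring in this patch follows the usual fairseq split where forward() chains extract_features() and output_layer(); a minimal self-contained sketch of that pattern (toy layer sizes and illustrative names only, not the patch's own code):

import torch.nn as nn

class TinyDecoder(nn.Module):
    # toy decoder illustrating the split: compute features first, project to the vocabulary last
    def __init__(self, hidden_dim=8, vocab_size=16):
        super().__init__()
        self.rnn = nn.LSTM(hidden_dim, hidden_dim)
        self.proj = nn.Linear(hidden_dim, vocab_size)

    def extract_features(self, x):
        # run the recurrent stack only; returns hidden states of shape (tgt_len, bsz, hidden_dim)
        out, _ = self.rnn(x)
        return out

    def output_layer(self, features):
        # project features to the vocabulary size
        return self.proj(features)

    def forward(self, x):
        # forward() is just the two steps chained together
        return self.output_layer(self.extract_features(x))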
@@ -660,6 +662,10 @@ def extract_features(self, prev_output_tokens, encoder_out=None, incremental_sta # T x B x C -> B x T x C x = x.transpose(1, 0) + if hasattr(self, 'additional_fc') and self.adaptive_softmax is None: + x = self.additional_fc(x) + x = F.dropout(x, p=self.dropout_out, training=self.training) + # srclen x tgtlen x bsz -> bsz x tgtlen x srclen if not self.training and self.attention is not None and self.need_attn: attn_scores = attn_scores.transpose(0, 2) @@ -669,12 +675,9 @@ def extract_features(self, prev_output_tokens, encoder_out=None, incremental_sta return x, attn_scores def output_layer(self, features, **kwargs): - """ project features to the vocabulary size.""" + """Project features to the vocabulary size.""" if self.adaptive_softmax is None: # project back to size of vocabulary - if hasattr(self, 'additional_fc'): - features = self.additional_fc(features) - features = F.dropout(features, p=self.dropout_out, training=self.training) if self.share_input_output_embed: return F.linear(features, self.embed_tokens.weight) else: From 7f3a0bce505883c4d4d71783e4b1b297c4a88dd8 Mon Sep 17 00:00:00 2001 From: freewym Date: Wed, 11 Dec 2019 14:33:14 -0500 Subject: [PATCH 048/119] code adaptation/changes according to the commits on Dec 11, 2019 --- fairseq/data/token_dictionary.py | 40 ++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/fairseq/data/token_dictionary.py b/fairseq/data/token_dictionary.py index 75d327bbd..fc9e6fbef 100644 --- a/fairseq/data/token_dictionary.py +++ b/fairseq/data/token_dictionary.py @@ -4,9 +4,9 @@ # LICENSE file in the root directory of this source tree. import torch - -from fairseq.tokenizer import tokenize_line from fairseq.data import Dictionary, data_utils +from fairseq.file_io import PathManager +from fairseq.tokenizer import tokenize_line class TokenDictionary(Dictionary): @@ -14,11 +14,11 @@ class TokenDictionary(Dictionary): def __init__( self, - pad='', - eos='
</s>', - unk='<unk>', - bos='<s>', - space='<space>', + pad="<pad>", + eos="</s>", + unk="<unk>", + bos="<s>", + space="<space>", extra_special_symbols=None, ): self.unk_word, self.pad_word, self.eos_word, self.bos_word, self.space_word = \ @@ -43,7 +43,7 @@ def string(self, tensor, bpe_symbol=None, escape_unk=False): We overwrite this since we would like to also ignore <pad>. """ if torch.is_tensor(tensor) and tensor.dim() == 2: - return '\n'.join(self.string(t, bpe_symbol, escape_unk) for t in tensor) + return "\n".join(self.string(t, bpe_symbol, escape_unk) for t in tensor) def token_string(i): if i == self.unk(): @@ -51,10 +51,18 @@ def token_string(i): else: return self[i] - if hasattr(self, 'bos_index'): - sent = ' '.join(token_string(i) for i in tensor if (i != self.eos()) and (i != self.bos()) and (i != self.pad())) + if hasattr(self, "bos_index"): + sent = " ".join( + token_string(i) + for i in tensor + if (i != self.eos()) and (i != self.bos()) and (i != self.pad()) + ) else: - sent = ' '.join(token_string(i) for i in tensor if i != self.eos() and i != self.pad()) + sent = " ".join( + token_string(i) + for i in tensor + if (i != self.eos()) and (i != self.pad()) + ) return data_utils.process_bpe_symbol(sent, bpe_symbol) def bos(self): @@ -87,17 +95,19 @@ def load(cls, f, f_non_lang_syms=None): if f_non_lang_syms is not None: assert isinstance(f_non_lang_syms, str) try: - with open(f_non_lang_syms, 'r', encoding='utf-8') as fd: + with PathManager.open(f_non_lang_syms, "r", encoding="utf-8") as fd: non_lang_syms = [x.rstrip() for x in fd.readlines()] except FileNotFoundError as fnfe: raise fnfe except UnicodeError: - raise Exception("Incorrect encoding detected in {}, please " - "rebuild the dataset".format(f)) + raise Exception( + "Incorrect encoding detected in {}, please " + "rebuild the dataset".format(f) + ) for sym in non_lang_syms: assert d.index(sym) != d.unk(), \ - '{} in {} is not in the dictionary'.format(sym, f_non_lang_syms) + "{} in {} is not in the dictionary".format(sym, f_non_lang_syms) d.non_lang_syms = non_lang_syms return d From 5744fb994fa5159cebe4fdc3c8f759e3942c342b Mon Sep 17 00:00:00 2001 From: freewym Date: Sun, 15 Dec 2019 17:51:29 -0500 Subject: [PATCH 049/119] allows text2vocabulary.py to accept an existing vocabulary with its first column as words --- examples/asr_swbd/run.sh | 4 +- speech_tools/text2vocabulary.py | 73 ++++++++++++++++++++------------- 2 files changed, 46 insertions(+), 31 deletions(-) diff --git a/examples/asr_swbd/run.sh b/examples/asr_swbd/run.sh index c4a8a296d..4b9b0aa7d 100755 --- a/examples/asr_swbd/run.sh +++ b/examples/asr_swbd/run.sh @@ -173,7 +173,7 @@ if [ $stage -le 2 ]; then python3 ../../scripts/spm_encode.py --model=${sentencepiece_model}.model --output_format=piece | \ cat $lmdatadir/$train_set.tokens - > $lmdatadir/train.tokens - echo "$0: making a dictionary with swbd+fisher text" + echo "$0: making a subword dictionary with swbd+fisher text" cat $lmdatadir/train.tokens | tr " " "\n" | grep -v -e "^\s*$" | sort | \ uniq -c | awk '{print $2,$1}' > $dict wc -l $dict @@ -183,7 +183,7 @@ lmdict=$dict if [ $stage -le 3 ]; then echo "Stage 3: Text Binarization for subword LM Training" mkdir -p $lmdatadir/logs - for dataset in $test_set; do test_paths="$test_paths $lmdatadir/$dataset.tokens"; done + test_paths= && for dataset in $test_set; do test_paths="$test_paths $lmdatadir/$dataset.tokens"; done test_paths=$(echo $test_paths | awk '{$1=$1;print}' | tr ' ' ',') ${decode_cmd} $lmdatadir/logs/preprocess.log \ python3 ../../preprocess.py --task language_modeling_for_asr \
diff --git a/speech_tools/text2vocabulary.py b/speech_tools/text2vocabulary.py index 9168c470c..d1a90b83a 100755 --- a/speech_tools/text2vocabulary.py +++ b/speech_tools/text2vocabulary.py @@ -22,10 +22,14 @@ def get_parser(): parser.add_argument('--exclude', type=str, default=None, help='space separated, list of excluding words, ' 'e.g., etc.') + parser.add_argument('--vocab', type=str, default=None, + help='path to the vocabulary file. If not None, calculate' + 'OOV stats with the provided vocabulary and output the ' + 'same vocabulary with word counts') parser.add_argument('--valid-text', type=str, default=None, - help='path to the validation text') + help='path to the validation text file') parser.add_argument('--test-text', type=str, default=None, - help='path to the test text') + help='colon separated paths to the test text file(s)') parser.add_argument('text_files', nargs='*', help='input text files') # fmt: on @@ -47,28 +51,38 @@ def main(args): counter.update(tokens) total_count = sum(counter.values()) - most_common = counter.most_common(args.vocabsize) - cutoff_point = 0 invocab_count = 0 - for elem in most_common: - if elem[1] < args.cutoff: - break - invocab_count += elem[1] - cutoff_point += 1 - cutoff_freq = most_common[cutoff_point - 1][1] - most_common = most_common[:cutoff_point] + if args.vocab is None: + most_common = counter.most_common(args.vocabsize) + cutoff_point = 0 + for elem in most_common: + if elem[1] < args.cutoff: + break + invocab_count += elem[1] + cutoff_point += 1 + cutoff_freq = most_common[cutoff_point - 1][1] + most_common = most_common[:cutoff_point] + vocab_set = set(list(zip(*most_common))[0]) + else: + print('using the provided vocabulary:', file=sys.stderr) + with open(args.vocab, 'r', encoding='utf-8') as f: + vocab_set = set([line.rstrip().split()[0] for line in f]) + most_common = [] + for word in vocab_set: + invocab_count += counter[word] + most_common.append((word, counter[word])) + + # words in vocabulary are lexically sorted + for w, c in sorted(most_common, key=lambda x: x[0]): + print('{} {:d}'.format(w, c)) oov_rate = 1. - float(invocab_count) / total_count print('training set:', file=sys.stderr) print(' total #tokens={:d}'.format(total_count), file=sys.stderr) print(' OOV rate={:.2f}%'.format(oov_rate * 100), file=sys.stderr) - print(' cutoff frequency={:d}'.format(cutoff_freq), file=sys.stderr) + if args.vocab is None: + print(' cutoff frequency={:d}'.format(cutoff_freq), file=sys.stderr) - # words in vocabulary are lexically sorted - for w, c in sorted(most_common, key=lambda x: x[0]): - print('{} {:d}'.format(w, c)) - - vocab_set = set(list(zip(*most_common))[0]) if args.valid_text is not None: total_count = 0 invocab_count = 0 @@ -84,18 +98,19 @@ def main(args): print(' OOV rate={:.2f}%'.format(oov_rate * 100), file=sys.stderr) if args.test_text is not None: - total_count = 0 - invocab_count = 0 - with open(args.test_text, 'r', encoding='utf-8') as f: - for line in f: - tokens = line.rstrip().split()[args.skip_ncols:] - tokens = [tok for tok in tokens if tok not in exclude] - total_count += len(tokens) - invocab_count += len([tok for tok in tokens if tok in vocab_set]) - oov_rate = 1. 
- float(invocab_count) / total_count - print('test set:', file=sys.stderr) - print(' total #tokens={:d}'.format(total_count), file=sys.stderr) - print(' OOV rate={:.2f}%'.format(oov_rate * 100), file=sys.stderr) + for k, path in enumerate(args.test_text.split(':')): + total_count = 0 + invocab_count = 0 + with open(path, 'r', encoding='utf-8') as f: + for line in f: + tokens = line.rstrip().split()[args.skip_ncols:] + tokens = [tok for tok in tokens if tok not in exclude] + total_count += len(tokens) + invocab_count += len([tok for tok in tokens if tok in vocab_set]) + oov_rate = 1. - float(invocab_count) / total_count + print('test set{}:'.format(k) if k > 0 else 'test set:', file=sys.stderr) + print(' total #tokens={:d}'.format(total_count), file=sys.stderr) + print(' OOV rate={:.2f}%'.format(oov_rate * 100), file=sys.stderr) if __name__ == '__main__': From 461703677f49ab099f5026d3250ee034a576d4ea Mon Sep 17 00:00:00 2001 From: freewym Date: Mon, 16 Dec 2019 17:43:37 -0500 Subject: [PATCH 050/119] re-organize the codebase to isolate espresso from fairseq --- espresso/__init__.py | 12 ++++++++++ espresso/criterions/__init__.py | 14 +++++++++++ .../criterions/cross_entropy_with_wer.py | 8 ++++--- .../label_smoothed_cross_entropy_with_wer.py | 8 ++++--- espresso/data/__init__.py | 17 ++++++++++++++ {fairseq => espresso}/data/scp_dataset.py | 0 {fairseq => espresso}/data/speech_dataset.py | 5 ++-- .../data/token_dictionary.py | 0 .../espresso_logo.png | Bin espresso/models/__init__.py | 14 +++++++++++ .../models/external_language_model.py | 7 +++--- {fairseq => espresso}/models/speech_fconv.py | 10 ++++---- {fairseq => espresso}/models/speech_lstm.py | 7 +++--- .../models/speech_transformer.py | 10 ++++---- .../tensorized_lookahead_language_model.py | 8 +++---- espresso/modules/__init__.py | 12 ++++++++++ .../modules/speech_attention.py | 0 espresso/optim/__init__.py | 14 +++++++++++ espresso/optim/lr_scheduler/__init__.py | 14 +++++++++++ .../lr_scheduler/reduce_lr_on_plateau_v2.py | 4 ++-- .../speech_recognize.py | 10 ++++---- speech_train.py => espresso/speech_train.py | 0 espresso/tasks/__init__.py | 14 +++++++++++ .../tasks/language_modeling_for_asr.py | 5 ++-- .../tasks/speech_recognition.py | 7 ++++-- {speech_tools => espresso/tools}/.gitignore | 0 {speech_tools => espresso/tools}/Makefile | 0 {speech_tools => espresso/tools}/__init__.py | 0 .../tools}/compute_wer.py | 2 +- {speech_tools => espresso/tools}/dump.sh | 0 .../tools}/tensorized_prefix_tree.py | 7 +++--- .../tools}/text2token.py | 0 .../tools}/text2vocabulary.py | 4 ++-- {speech_tools => espresso/tools}/utils.py | 3 ++- {fairseq => espresso/tools}/wer.py | 2 +- examples/asr_librispeech/local/data_prep.sh | 2 +- .../local/download_and_untar.sh | 2 +- examples/asr_librispeech/path.sh | 6 ++--- examples/asr_librispeech/run.sh | 10 ++++---- examples/asr_librispeech/steps | 2 +- examples/asr_librispeech/utils | 2 +- examples/asr_swbd/local/MSU_single_letter.txt | 2 +- examples/asr_swbd/local/dict.patch | 2 +- examples/asr_swbd/local/eval2000_data_prep.sh | 2 +- examples/asr_swbd/local/extend_segments.pl | 2 +- examples/asr_swbd/local/fisher_map_words.pl | 2 +- .../asr_swbd/local/format_acronyms_dict.py | 2 +- examples/asr_swbd/local/map_acronyms_ctm.py | 2 +- .../local/map_acronyms_transcripts.py | 2 +- examples/asr_swbd/local/rt03_data_prep.sh | 2 +- .../asr_swbd/local/swbd1_data_download.sh | 2 +- examples/asr_swbd/local/swbd1_data_prep.sh | 2 +- .../asr_swbd/local/swbd1_fix_speakerid.pl | 2 +- 
examples/asr_swbd/local/swbd1_map_words.pl | 2 +- examples/asr_swbd/local/swbd1_prepare_dict.sh | 2 +- examples/asr_swbd/path.sh | 6 ++--- examples/asr_swbd/run.sh | 22 +++++++++--------- examples/asr_swbd/steps | 2 +- examples/asr_swbd/utils | 2 +- examples/asr_wsj/local/find_transcripts.pl | 2 +- examples/asr_wsj/local/flist2scp.pl | 2 +- examples/asr_wsj/local/ndx2flist.pl | 2 +- .../asr_wsj/local/normalize_transcript.pl | 2 +- examples/asr_wsj/local/wsj_data_prep.sh | 2 +- examples/asr_wsj/path.sh | 6 ++--- examples/asr_wsj/run.sh | 16 ++++++------- examples/asr_wsj/steps | 2 +- examples/asr_wsj/utils | 2 +- fairseq/data/__init__.py | 9 ------- tests/espresso/__init__.py | 0 tests/{ => espresso}/test_speech_dataset.py | 10 +++++--- tests/{ => espresso}/test_speech_utils.py | 4 ++-- 72 files changed, 238 insertions(+), 123 deletions(-) create mode 100644 espresso/__init__.py create mode 100644 espresso/criterions/__init__.py rename {fairseq => espresso}/criterions/cross_entropy_with_wer.py (98%) rename {fairseq => espresso}/criterions/label_smoothed_cross_entropy_with_wer.py (98%) create mode 100644 espresso/data/__init__.py rename {fairseq => espresso}/data/scp_dataset.py (100%) rename {fairseq => espresso}/data/speech_dataset.py (98%) rename {fairseq => espresso}/data/token_dictionary.py (100%) rename espresso_logo.png => espresso/espresso_logo.png (100%) create mode 100644 espresso/models/__init__.py rename {fairseq => espresso}/models/external_language_model.py (99%) rename {fairseq => espresso}/models/speech_fconv.py (99%) rename {fairseq => espresso}/models/speech_lstm.py (99%) rename {fairseq => espresso}/models/speech_transformer.py (98%) rename {fairseq => espresso}/models/tensorized_lookahead_language_model.py (98%) create mode 100644 espresso/modules/__init__.py rename {fairseq => espresso}/modules/speech_attention.py (100%) create mode 100644 espresso/optim/__init__.py create mode 100644 espresso/optim/lr_scheduler/__init__.py rename {fairseq => espresso}/optim/lr_scheduler/reduce_lr_on_plateau_v2.py (91%) rename speech_recognize.py => espresso/speech_recognize.py (96%) rename speech_train.py => espresso/speech_train.py (100%) create mode 100644 espresso/tasks/__init__.py rename {fairseq => espresso}/tasks/language_modeling_for_asr.py (97%) rename {fairseq => espresso}/tasks/speech_recognition.py (99%) rename {speech_tools => espresso/tools}/.gitignore (100%) rename {speech_tools => espresso/tools}/Makefile (100%) rename {speech_tools => espresso/tools}/__init__.py (100%) rename {speech_tools => espresso/tools}/compute_wer.py (98%) rename {speech_tools => espresso/tools}/dump.sh (100%) rename {speech_tools => espresso/tools}/tensorized_prefix_tree.py (97%) rename {speech_tools => espresso/tools}/text2token.py (100%) rename {speech_tools => espresso/tools}/text2vocabulary.py (98%) rename {speech_tools => espresso/tools}/utils.py (99%) rename {fairseq => espresso/tools}/wer.py (99%) create mode 100644 tests/espresso/__init__.py rename tests/{ => espresso}/test_speech_dataset.py (97%) rename tests/{ => espresso}/test_speech_utils.py (99%) diff --git a/espresso/__init__.py b/espresso/__init__.py new file mode 100644 index 000000000..98cb3449e --- /dev/null +++ b/espresso/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import espresso.criterions # noqa +import espresso.models # noqa +import espresso.modules # noqa +import espresso.optim # noqa +import espresso.optim.lr_scheduler # noqa +import espresso.tasks # noqa diff --git a/espresso/criterions/__init__.py b/espresso/criterions/__init__.py new file mode 100644 index 000000000..3edbc58f4 --- /dev/null +++ b/espresso/criterions/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import os + + +# automatically import any Python files in the criterions/ directory +for file in os.listdir(os.path.dirname(__file__)): + if not file.startswith("_") and not file.startswith(".") and file.endswith(".py"): + criterion_name = file[: file.find(".py")] + importlib.import_module("espresso.criterions." + criterion_name) diff --git a/fairseq/criterions/cross_entropy_with_wer.py b/espresso/criterions/cross_entropy_with_wer.py similarity index 98% rename from fairseq/criterions/cross_entropy_with_wer.py rename to espresso/criterions/cross_entropy_with_wer.py index 92d5b602f..ad3ebd14e 100644 --- a/fairseq/criterions/cross_entropy_with_wer.py +++ b/espresso/criterions/cross_entropy_with_wer.py @@ -7,13 +7,15 @@ import torch import torch.nn.functional as F -from fairseq import utils, wer +from fairseq import utils from fairseq.data import data_utils from fairseq.models import FairseqIncrementalDecoder from fairseq.options import eval_str_list -from . import FairseqCriterion, register_criterion -from .cross_entropy import CrossEntropyCriterion +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.criterions.cross_entropy import CrossEntropyCriterion + +from espresso.tools import wer @register_criterion('cross_entropy_with_wer') diff --git a/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py b/espresso/criterions/label_smoothed_cross_entropy_with_wer.py similarity index 98% rename from fairseq/criterions/label_smoothed_cross_entropy_with_wer.py rename to espresso/criterions/label_smoothed_cross_entropy_with_wer.py index ca3369458..dca7c1e40 100644 --- a/fairseq/criterions/label_smoothed_cross_entropy_with_wer.py +++ b/espresso/criterions/label_smoothed_cross_entropy_with_wer.py @@ -6,13 +6,15 @@ import numpy as np import torch -from fairseq import utils, wer +from fairseq import utils from fairseq.data import data_utils from fairseq.models import FairseqIncrementalDecoder from fairseq.options import eval_str_list -from . import FairseqCriterion, register_criterion -from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion + +from espresso.tools import wer def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True, diff --git a/espresso/data/__init__.py b/espresso/data/__init__.py new file mode 100644 index 000000000..b8bfbc584 --- /dev/null +++ b/espresso/data/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from .token_dictionary import TokenDictionary +from .scp_dataset import ScpDataset, ScpCachedDataset, ScpInMemoryDataset, TokenTextDataset +from .speech_dataset import SpeechDataset + +__all__ = [ + 'ScpDataset', + 'ScpCachedDataset', + 'ScpInMemoryDataset', + 'TokenDictionary', + 'TokenTextDataset', + 'SpeechDataset', +] diff --git a/fairseq/data/scp_dataset.py b/espresso/data/scp_dataset.py similarity index 100% rename from fairseq/data/scp_dataset.py rename to espresso/data/scp_dataset.py diff --git a/fairseq/data/speech_dataset.py b/espresso/data/speech_dataset.py similarity index 98% rename from fairseq/data/speech_dataset.py rename to espresso/data/speech_dataset.py index d5533d23a..10b5cab34 100644 --- a/fairseq/data/speech_dataset.py +++ b/espresso/data/speech_dataset.py @@ -6,8 +6,9 @@ import numpy as np import torch -from . import data_utils, FairseqDataset -import speech_tools.utils as speech_utils +from fairseq.data import data_utils, FairseqDataset + +import espresso.tools.utils as speech_utils def collate( diff --git a/fairseq/data/token_dictionary.py b/espresso/data/token_dictionary.py similarity index 100% rename from fairseq/data/token_dictionary.py rename to espresso/data/token_dictionary.py diff --git a/espresso_logo.png b/espresso/espresso_logo.png similarity index 100% rename from espresso_logo.png rename to espresso/espresso_logo.png diff --git a/espresso/models/__init__.py b/espresso/models/__init__.py new file mode 100644 index 000000000..928f3caea --- /dev/null +++ b/espresso/models/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import os + + +# automatically import any Python files in the models/ directory +for file in os.listdir(os.path.dirname(__file__)): + if not file.startswith("_") and not file.startswith(".") and file.endswith(".py"): + model_name = file[:file.find(".py")] + importlib.import_module("espresso.models." + model_name) diff --git a/fairseq/models/external_language_model.py b/espresso/models/external_language_model.py similarity index 99% rename from fairseq/models/external_language_model.py rename to espresso/models/external_language_model.py index a1d44f6de..1805f3d55 100644 --- a/fairseq/models/external_language_model.py +++ b/espresso/models/external_language_model.py @@ -9,11 +9,10 @@ import torch.nn.functional as F from fairseq import options, utils -from fairseq.data import TokenDictionary +from fairseq.models import FairseqIncrementalDecoder, FairseqLanguageModel -from . 
import FairseqIncrementalDecoder, FairseqLanguageModel - -from speech_tools.utils import tokenize, lexical_prefix_tree +from espresso.data import TokenDictionary +from espresso.tools.utils import tokenize, lexical_prefix_tree def _clone_cached_state(cached_state): diff --git a/fairseq/models/speech_fconv.py b/espresso/models/speech_fconv.py similarity index 99% rename from fairseq/models/speech_fconv.py rename to espresso/models/speech_fconv.py index 64a28ae5e..36e682adf 100644 --- a/fairseq/models/speech_fconv.py +++ b/espresso/models/speech_fconv.py @@ -13,11 +13,7 @@ register_model, register_model_architecture, ) -from fairseq.modules import GradMultiply - -from .speech_lstm import ConvBNReLU - -from .fconv import ( +from fairseq.models.fconv import ( ConvTBC, FConvModel, FConvEncoder, @@ -25,8 +21,10 @@ Linear, extend_conv_spec, ) +from fairseq.modules import GradMultiply -import speech_tools.utils as speech_utils +from espresso.models.speech_lstm import ConvBNReLU +import espresso.tools.utils as speech_utils @register_model('speech_fconv') diff --git a/fairseq/models/speech_lstm.py b/espresso/models/speech_lstm.py similarity index 99% rename from fairseq/models/speech_lstm.py rename to espresso/models/speech_lstm.py index d682d46d0..df1b5b215 100644 --- a/fairseq/models/speech_lstm.py +++ b/espresso/models/speech_lstm.py @@ -23,10 +23,11 @@ LSTMCell, Linear, ) -from fairseq.modules import AdaptiveSoftmax, speech_attention -from fairseq.tasks.speech_recognition import SpeechRecognitionTask +from fairseq.modules import AdaptiveSoftmax -import speech_tools.utils as speech_utils +from espresso.modules import speech_attention +from espresso.tasks.speech_recognition import SpeechRecognitionTask +import espresso.tools.utils as speech_utils @register_model('speech_lstm') diff --git a/fairseq/models/speech_transformer.py b/espresso/models/speech_transformer.py similarity index 98% rename from fairseq/models/speech_transformer.py rename to espresso/models/speech_transformer.py index 91996b752..3a527ea3e 100644 --- a/fairseq/models/speech_transformer.py +++ b/espresso/models/speech_transformer.py @@ -14,11 +14,7 @@ register_model, register_model_architecture, ) -from fairseq.modules import LayerNorm - -from .speech_lstm import ConvBNReLU - -from .transformer import ( +from fairseq.models.transformer import ( Embedding, Linear, TransformerModel, @@ -26,8 +22,10 @@ TransformerDecoder, TransformerEncoderLayer, ) +from fairseq.modules import LayerNorm -import speech_tools.utils as speech_utils +from espresso.models.speech_lstm import ConvBNReLU +import espresso.tools.utils as speech_utils DEFAULT_MAX_SOURCE_POSITIONS = 9999 DEFAULT_MAX_TARGET_POSITIONS = 999 diff --git a/fairseq/models/tensorized_lookahead_language_model.py b/espresso/models/tensorized_lookahead_language_model.py similarity index 98% rename from fairseq/models/tensorized_lookahead_language_model.py rename to espresso/models/tensorized_lookahead_language_model.py index 86280c221..b9e39d00d 100644 --- a/fairseq/models/tensorized_lookahead_language_model.py +++ b/espresso/models/tensorized_lookahead_language_model.py @@ -7,12 +7,12 @@ import torch from fairseq.models import FairseqLanguageModel, FairseqIncrementalDecoder -from fairseq.models.external_language_model import RawOutExternalLanguageModelBase -from fairseq.data import TokenDictionary from fairseq import utils -from speech_tools.tensorized_prefix_tree import TensorizedPrefixTree -from speech_tools.utils import tokenize +from espresso.data import TokenDictionary +from 
espresso.models.external_language_model import RawOutExternalLanguageModelBase +from espresso.tools.tensorized_prefix_tree import TensorizedPrefixTree +from espresso.tools.utils import tokenize def _clone_cached_state(cached_state): diff --git a/espresso/modules/__init__.py b/espresso/modules/__init__.py new file mode 100644 index 000000000..1e6b35acd --- /dev/null +++ b/espresso/modules/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .speech_attention import BahdanauAttention, LuongAttention + + +__all__ = [ + 'BahdanauAttention', + 'LuongAttention', +] diff --git a/fairseq/modules/speech_attention.py b/espresso/modules/speech_attention.py similarity index 100% rename from fairseq/modules/speech_attention.py rename to espresso/modules/speech_attention.py diff --git a/espresso/optim/__init__.py b/espresso/optim/__init__.py new file mode 100644 index 000000000..f922fa16e --- /dev/null +++ b/espresso/optim/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import os + + +# automatically import any Python files in the optim/ directory +for file in os.listdir(os.path.dirname(__file__)): + if not file.startswith("_") and not file.startswith(".") and file.endswith(".py"): + module = file[:file.find(".py")] + importlib.import_module("espresso.optim." + module) diff --git a/espresso/optim/lr_scheduler/__init__.py b/espresso/optim/lr_scheduler/__init__.py new file mode 100644 index 000000000..a67e46579 --- /dev/null +++ b/espresso/optim/lr_scheduler/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import os + + +# automatically import any Python files in the optim/lr_scheduler/ directory +for file in os.listdir(os.path.dirname(__file__)): + if not file.startswith("_") and not file.startswith(".") and file.endswith(".py"): + module = file[:file.find(".py")] + importlib.import_module("espresso.optim.lr_scheduler." + module) diff --git a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau_v2.py b/espresso/optim/lr_scheduler/reduce_lr_on_plateau_v2.py similarity index 91% rename from fairseq/optim/lr_scheduler/reduce_lr_on_plateau_v2.py rename to espresso/optim/lr_scheduler/reduce_lr_on_plateau_v2.py index 54b8ab64b..2435f3ce1 100644 --- a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau_v2.py +++ b/espresso/optim/lr_scheduler/reduce_lr_on_plateau_v2.py @@ -5,8 +5,8 @@ import torch.optim.lr_scheduler -from . 
import FairseqLRScheduler, register_lr_scheduler -from .reduce_lr_on_plateau import ReduceLROnPlateau +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler +from fairseq.optim.lr_scheduler.reduce_lr_on_plateau import ReduceLROnPlateau @register_lr_scheduler('reduce_lr_on_plateau_v2') diff --git a/speech_recognize.py b/espresso/speech_recognize.py similarity index 96% rename from speech_recognize.py rename to espresso/speech_recognize.py index e40248edf..e198990c1 100755 --- a/speech_recognize.py +++ b/espresso/speech_recognize.py @@ -12,13 +12,15 @@ import torch -from fairseq import wer, checkpoint_utils, options, progress_bar, tasks, utils +from fairseq import checkpoint_utils, options, progress_bar, tasks, utils from fairseq.meters import StopwatchMeter, TimeMeter from fairseq.models import FairseqLanguageModel -from fairseq.models.external_language_model import MultiLevelLanguageModel -from fairseq.models.tensorized_lookahead_language_model import TensorizedLookaheadLanguageModel from fairseq.utils import import_user_module -from speech_tools.utils import plot_attention, sequence_mask + +from espresso.models.external_language_model import MultiLevelLanguageModel +from espresso.models.tensorized_lookahead_language_model import TensorizedLookaheadLanguageModel +from espresso.tools import wer +from espresso.tools.utils import plot_attention, sequence_mask def main(args): diff --git a/speech_train.py b/espresso/speech_train.py similarity index 100% rename from speech_train.py rename to espresso/speech_train.py diff --git a/espresso/tasks/__init__.py b/espresso/tasks/__init__.py new file mode 100644 index 000000000..6739bb677 --- /dev/null +++ b/espresso/tasks/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import os + + +# automatically import any Python files in the tasks/ directory +for file in os.listdir(os.path.dirname(__file__)): + if not file.startswith("_") and not file.startswith(".") and file.endswith(".py"): + task_name = file[:file.find(".py")] + importlib.import_module("espresso.tasks." + task_name) diff --git a/fairseq/tasks/language_modeling_for_asr.py b/espresso/tasks/language_modeling_for_asr.py similarity index 97% rename from fairseq/tasks/language_modeling_for_asr.py rename to espresso/tasks/language_modeling_for_asr.py index b0b96bd3a..8547423c3 100644 --- a/fairseq/tasks/language_modeling_for_asr.py +++ b/espresso/tasks/language_modeling_for_asr.py @@ -8,10 +8,11 @@ import os from fairseq import tokenizer -from fairseq.data import TokenDictionary from fairseq.tasks import register_task -from .language_modeling import LanguageModelingTask +from fairseq.tasks.language_modeling import LanguageModelingTask + +from espresso.data import TokenDictionary @register_task("language_modeling_for_asr") diff --git a/fairseq/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py similarity index 99% rename from fairseq/tasks/speech_recognition.py rename to espresso/tasks/speech_recognition.py index 4d5df3388..8e7fb64e3 100644 --- a/fairseq/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -13,14 +13,17 @@ from fairseq.data import ( ConcatDataset, data_utils, +) + +from fairseq.tasks import FairseqTask, register_task + +from espresso.data import ( ScpCachedDataset, SpeechDataset, TokenDictionary, TokenTextDataset, ) -from . 
import FairseqTask, register_task - @register_task('speech_recognition') class SpeechRecognitionTask(FairseqTask): diff --git a/speech_tools/.gitignore b/espresso/tools/.gitignore similarity index 100% rename from speech_tools/.gitignore rename to espresso/tools/.gitignore diff --git a/speech_tools/Makefile b/espresso/tools/Makefile similarity index 100% rename from speech_tools/Makefile rename to espresso/tools/Makefile diff --git a/speech_tools/__init__.py b/espresso/tools/__init__.py similarity index 100% rename from speech_tools/__init__.py rename to espresso/tools/__init__.py diff --git a/speech_tools/compute_wer.py b/espresso/tools/compute_wer.py similarity index 98% rename from speech_tools/compute_wer.py rename to espresso/tools/compute_wer.py index 05003fb9c..62a44e423 100755 --- a/speech_tools/compute_wer.py +++ b/espresso/tools/compute_wer.py @@ -8,7 +8,7 @@ import sys, re from collections import Counter -from utils import edit_distance +from espresso.tools.utils import edit_distance def get_parser(): diff --git a/speech_tools/dump.sh b/espresso/tools/dump.sh similarity index 100% rename from speech_tools/dump.sh rename to espresso/tools/dump.sh diff --git a/speech_tools/tensorized_prefix_tree.py b/espresso/tools/tensorized_prefix_tree.py similarity index 97% rename from speech_tools/tensorized_prefix_tree.py rename to espresso/tools/tensorized_prefix_tree.py index 33d6f3728..77bf37697 100644 --- a/speech_tools/tensorized_prefix_tree.py +++ b/espresso/tools/tensorized_prefix_tree.py @@ -5,11 +5,12 @@ import os, re import numpy as np +from typing import * + import torch -from typing import * -from fairseq.data import TokenDictionary -from speech_tools.utils import lexical_prefix_tree +from espresso.data import TokenDictionary +from espresso.tools.utils import lexical_prefix_tree class TensorizedPrefixTree: diff --git a/speech_tools/text2token.py b/espresso/tools/text2token.py similarity index 100% rename from speech_tools/text2token.py rename to espresso/tools/text2token.py diff --git a/speech_tools/text2vocabulary.py b/espresso/tools/text2vocabulary.py similarity index 98% rename from speech_tools/text2vocabulary.py rename to espresso/tools/text2vocabulary.py index d1a90b83a..ee3d314c0 100755 --- a/speech_tools/text2vocabulary.py +++ b/espresso/tools/text2vocabulary.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
import argparse -import sys +import os, sys from collections import Counter @@ -98,7 +98,7 @@ def main(args): print(' OOV rate={:.2f}%'.format(oov_rate * 100), file=sys.stderr) if args.test_text is not None: - for k, path in enumerate(args.test_text.split(':')): + for k, path in enumerate(args.test_text.split(os.pathsep)): total_count = 0 invocab_count = 0 with open(path, 'r', encoding='utf-8') as f: diff --git a/speech_tools/utils.py b/espresso/tools/utils.py similarity index 99% rename from speech_tools/utils.py rename to espresso/tools/utils.py index d5a4bad74..cdf124d7f 100644 --- a/speech_tools/utils.py +++ b/espresso/tools/utils.py @@ -11,7 +11,8 @@ import torch from fairseq import utils -from fairseq.data import TokenDictionary + +from espresso.data import TokenDictionary def tokenize(sent, space='', non_lang_syms=None): diff --git a/fairseq/wer.py b/espresso/tools/wer.py similarity index 99% rename from fairseq/wer.py rename to espresso/tools/wer.py index 52aa2e5b6..213fe9d4a 100644 --- a/fairseq/wer.py +++ b/espresso/tools/wer.py @@ -7,7 +7,7 @@ from collections import Counter, OrderedDict -import speech_tools.utils as speech_utils +import espresso.tools.utils as speech_utils class Scorer(object): diff --git a/examples/asr_librispeech/local/data_prep.sh b/examples/asr_librispeech/local/data_prep.sh index 3000aeaca..c670b1b56 120000 --- a/examples/asr_librispeech/local/data_prep.sh +++ b/examples/asr_librispeech/local/data_prep.sh @@ -1 +1 @@ -../../../speech_tools/kaldi/egs/librispeech/s5/local/data_prep.sh \ No newline at end of file +../../../espresso/tools/kaldi/egs/librispeech/s5/local/data_prep.sh \ No newline at end of file diff --git a/examples/asr_librispeech/local/download_and_untar.sh b/examples/asr_librispeech/local/download_and_untar.sh index 4edb356c0..d258978c0 120000 --- a/examples/asr_librispeech/local/download_and_untar.sh +++ b/examples/asr_librispeech/local/download_and_untar.sh @@ -1 +1 @@ -../../../speech_tools/kaldi/egs/librispeech/s5/local/download_and_untar.sh \ No newline at end of file +../../../espresso/tools/kaldi/egs/librispeech/s5/local/download_and_untar.sh \ No newline at end of file diff --git a/examples/asr_librispeech/path.sh b/examples/asr_librispeech/path.sh index 19308a610..863f5de3e 100644 --- a/examples/asr_librispeech/path.sh +++ b/examples/asr_librispeech/path.sh @@ -1,5 +1,5 @@ MAIN_ROOT=$PWD/../.. -KALDI_ROOT=$MAIN_ROOT/speech_tools/kaldi +KALDI_ROOT=$MAIN_ROOT/espresso/tools/kaldi # BEGIN from kaldi path.sh [ -f $KALDI_ROOT/tools/env.sh ] && . 
$KALDI_ROOT/tools/env.sh @@ -10,6 +10,6 @@ export LC_ALL=C # END export PATH=~/anaconda3/bin:$PATH -export PATH=$MAIN_ROOT:$MAIN_ROOT/speech_tools:$PATH -export PYTHONPATH=$MAIN_ROOT:$MAIN_ROOT/speech_tools:$PYTHONPATH +export PATH=$MAIN_ROOT:$MAIN_ROOT/espresso:$MAIN_ROOT/espresso/tools:$PATH +export PYTHONPATH=$MAIN_ROOT:$MAIN_ROOT/espresso:$MAIN_ROOT/espresso/tools:$PYTHONPATH export PYTHONUNBUFFERED=1 diff --git a/examples/asr_librispeech/run.sh b/examples/asr_librispeech/run.sh index b78e80fa7..1082617ed 100755 --- a/examples/asr_librispeech/run.sh +++ b/examples/asr_librispeech/run.sh @@ -147,7 +147,7 @@ if [ ${stage} -le 4 ]; then for dataset in $test_set; do test_paths="$test_paths $lmdatadir/$dataset.tokens"; done test_paths=$(echo $test_paths | awk '{$1=$1;print}' | tr ' ' ',') ${decode_cmd} $lmdatadir/logs/preprocess.log \ - python3 ../../preprocess.py --task language_modeling_for_asr \ + python3 ../../preprocess.py --user-dir espresso --task language_modeling_for_asr \ --workers 50 --srcdict $lmdict --only-source \ --trainpref $lmdatadir/train.tokens \ --validpref $lmdatadir/$valid_set.tokens \ @@ -167,7 +167,7 @@ if [ ${stage} -le 5 ]; then mkdir -p $lmdir/logs log_file=$lmdir/logs/train.log [ -f $lmdir/checkpoint_last.pt ] && log_file="-a $log_file" - CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../train.py $lmdatadir --seed 1 \ + CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../train.py $lmdatadir --seed 1 --user-dir espresso \ --task language_modeling_for_asr --dict $lmdict \ --log-interval 8000 --log-format simple \ --num-workers 0 --max-tokens 32000 --max-sentences 1024 --curriculum 1 \ @@ -188,7 +188,7 @@ if [ ${stage} -le 6 ]; then test_set_array=($test_set) for i in $(seq 0 $num); do log_file=$lmdir/logs/evaluation_${test_set_array[$i]}.log - python3 ../../eval_lm.py $lmdatadir --cpu \ + python3 ../../eval_lm.py $lmdatadir --user-dir espresso --cpu \ --task language_modeling_for_asr --dict $lmdict --gen-subset ${gen_set_array[$i]} \ --max-tokens 40960 --max-sentences 1536 --sample-break-mode eos \ --path $lmdir/$lm_checkpoint 2>&1 | tee $log_file @@ -205,7 +205,7 @@ if [ ${stage} -le 7 ]; then mkdir -p $dir/logs log_file=$dir/logs/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" - CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --seed 1 \ + CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --seed 1 --user-dir espresso \ --log-interval 4000 --log-format simple --print-training-sample-interval 2000 \ --num-workers 0 --max-tokens 26000 --max-sentences 24 --curriculum 1 \ --valid-subset $valid_subset --max-sentences-valid 48 --ddp-backend no_c10d \ @@ -237,7 +237,7 @@ if [ ${stage} -le 8 ]; then feat=${dumpdir}/$dataset/delta${do_delta}/feats.scp text=data/$dataset/token_text CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py \ - --max-tokens 15000 --max-sentences 24 --num-shards 1 --shard-id 0 \ + --user-dir espresso --max-tokens 15000 --max-sentences 24 --num-shards 1 --shard-id 0 \ --test-feat-files $feat --test-text-files $text \ --dict $dict --remove-bpe sentencepiece \ --max-source-positions 9999 --max-target-positions 999 \ diff --git a/examples/asr_librispeech/steps b/examples/asr_librispeech/steps index ec9b528ac..5871b3f9c 120000 --- a/examples/asr_librispeech/steps +++ b/examples/asr_librispeech/steps @@ -1 +1 @@ -../../speech_tools/kaldi/egs/wsj/s5/steps \ No newline at end of file +../../espresso/tools/kaldi/egs/wsj/s5/steps \ No newline at end of file diff --git a/examples/asr_librispeech/utils 
b/examples/asr_librispeech/utils index ea44d93b9..bc8958e91 120000 --- a/examples/asr_librispeech/utils +++ b/examples/asr_librispeech/utils @@ -1 +1 @@ -../../speech_tools/kaldi/egs/wsj/s5/utils \ No newline at end of file +../../espresso/tools/kaldi/egs/wsj/s5/utils \ No newline at end of file diff --git a/examples/asr_swbd/local/MSU_single_letter.txt b/examples/asr_swbd/local/MSU_single_letter.txt index 9b034a146..6a2170456 120000 --- a/examples/asr_swbd/local/MSU_single_letter.txt +++ b/examples/asr_swbd/local/MSU_single_letter.txt @@ -1 +1 @@ -../../../speech_tools/kaldi/egs/swbd/s5c/local/MSU_single_letter.txt \ No newline at end of file +../../../espresso/tools/kaldi/egs/swbd/s5c/local/MSU_single_letter.txt \ No newline at end of file diff --git a/examples/asr_swbd/local/dict.patch b/examples/asr_swbd/local/dict.patch index e2ead1dcf..17bf67e3a 120000 --- a/examples/asr_swbd/local/dict.patch +++ b/examples/asr_swbd/local/dict.patch @@ -1 +1 @@ -../../../speech_tools/kaldi/egs/swbd/s5c/local/dict.patch \ No newline at end of file +../../../espresso/tools/kaldi/egs/swbd/s5c/local/dict.patch \ No newline at end of file diff --git a/examples/asr_swbd/local/eval2000_data_prep.sh b/examples/asr_swbd/local/eval2000_data_prep.sh index 179705396..e7520f22f 120000 --- a/examples/asr_swbd/local/eval2000_data_prep.sh +++ b/examples/asr_swbd/local/eval2000_data_prep.sh @@ -1 +1 @@ -../../../speech_tools/kaldi/egs/swbd/s5c/local/eval2000_data_prep.sh \ No newline at end of file +../../../espresso/tools/kaldi/egs/swbd/s5c/local/eval2000_data_prep.sh \ No newline at end of file diff --git a/examples/asr_swbd/local/extend_segments.pl b/examples/asr_swbd/local/extend_segments.pl index 0ff7e3a1a..636ed40e0 120000 --- a/examples/asr_swbd/local/extend_segments.pl +++ b/examples/asr_swbd/local/extend_segments.pl @@ -1 +1 @@ -../../../speech_tools/kaldi/egs/swbd/s5c/local/extend_segments.pl \ No newline at end of file +../../../espresso/tools/kaldi/egs/swbd/s5c/local/extend_segments.pl \ No newline at end of file diff --git a/examples/asr_swbd/local/fisher_map_words.pl b/examples/asr_swbd/local/fisher_map_words.pl index 0b8445fc0..dd3afe1b9 120000 --- a/examples/asr_swbd/local/fisher_map_words.pl +++ b/examples/asr_swbd/local/fisher_map_words.pl @@ -1 +1 @@ -../../../speech_tools/kaldi/egs/swbd/s5c/local/fisher_map_words.pl \ No newline at end of file +../../../espresso/tools/kaldi/egs/swbd/s5c/local/fisher_map_words.pl \ No newline at end of file diff --git a/examples/asr_swbd/local/format_acronyms_dict.py b/examples/asr_swbd/local/format_acronyms_dict.py index c88fb9578..ea8e70380 120000 --- a/examples/asr_swbd/local/format_acronyms_dict.py +++ b/examples/asr_swbd/local/format_acronyms_dict.py @@ -1 +1 @@ -../../../speech_tools/kaldi/egs/swbd/s5c/local/format_acronyms_dict.py \ No newline at end of file +../../../espresso/tools/kaldi/egs/swbd/s5c/local/format_acronyms_dict.py \ No newline at end of file diff --git a/examples/asr_swbd/local/map_acronyms_ctm.py b/examples/asr_swbd/local/map_acronyms_ctm.py index 47c775d1c..eac5152ba 120000 --- a/examples/asr_swbd/local/map_acronyms_ctm.py +++ b/examples/asr_swbd/local/map_acronyms_ctm.py @@ -1 +1 @@ -../../../speech_tools/kaldi/egs/swbd/s5c/local/map_acronyms_ctm.py \ No newline at end of file +../../../espresso/tools/kaldi/egs/swbd/s5c/local/map_acronyms_ctm.py \ No newline at end of file diff --git a/examples/asr_swbd/local/map_acronyms_transcripts.py b/examples/asr_swbd/local/map_acronyms_transcripts.py index 9d1b9c8b7..5e827facb 120000 --- 
a/examples/asr_swbd/local/map_acronyms_transcripts.py +++ b/examples/asr_swbd/local/map_acronyms_transcripts.py @@ -1 +1 @@ -../../../speech_tools/kaldi/egs/swbd/s5c/local/map_acronyms_transcripts.py \ No newline at end of file +../../../espresso/tools/kaldi/egs/swbd/s5c/local/map_acronyms_transcripts.py \ No newline at end of file diff --git a/examples/asr_swbd/local/rt03_data_prep.sh b/examples/asr_swbd/local/rt03_data_prep.sh index 35e8bb102..49f60f4f0 120000 --- a/examples/asr_swbd/local/rt03_data_prep.sh +++ b/examples/asr_swbd/local/rt03_data_prep.sh @@ -1 +1 @@ -../../../speech_tools/kaldi/egs/swbd/s5c/local/rt03_data_prep.sh \ No newline at end of file +../../../espresso/tools/kaldi/egs/swbd/s5c/local/rt03_data_prep.sh \ No newline at end of file diff --git a/examples/asr_swbd/local/swbd1_data_download.sh b/examples/asr_swbd/local/swbd1_data_download.sh index 676f5e0b4..123e98a49 120000 --- a/examples/asr_swbd/local/swbd1_data_download.sh +++ b/examples/asr_swbd/local/swbd1_data_download.sh @@ -1 +1 @@ -../../../speech_tools/kaldi/egs/swbd/s5c/local/swbd1_data_download.sh \ No newline at end of file +../../../espresso/tools/kaldi/egs/swbd/s5c/local/swbd1_data_download.sh \ No newline at end of file diff --git a/examples/asr_swbd/local/swbd1_data_prep.sh b/examples/asr_swbd/local/swbd1_data_prep.sh index 7faee28eb..099935007 120000 --- a/examples/asr_swbd/local/swbd1_data_prep.sh +++ b/examples/asr_swbd/local/swbd1_data_prep.sh @@ -1 +1 @@ -../../../speech_tools/kaldi/egs/swbd/s5c/local/swbd1_data_prep.sh \ No newline at end of file +../../../espresso/tools/kaldi/egs/swbd/s5c/local/swbd1_data_prep.sh \ No newline at end of file diff --git a/examples/asr_swbd/local/swbd1_fix_speakerid.pl b/examples/asr_swbd/local/swbd1_fix_speakerid.pl index 83a348533..d4aabbd75 120000 --- a/examples/asr_swbd/local/swbd1_fix_speakerid.pl +++ b/examples/asr_swbd/local/swbd1_fix_speakerid.pl @@ -1 +1 @@ -../../../speech_tools/kaldi/egs/swbd/s5c/local/swbd1_fix_speakerid.pl \ No newline at end of file +../../../espresso/tools/kaldi/egs/swbd/s5c/local/swbd1_fix_speakerid.pl \ No newline at end of file diff --git a/examples/asr_swbd/local/swbd1_map_words.pl b/examples/asr_swbd/local/swbd1_map_words.pl index f35ddcb7f..d13805d9f 120000 --- a/examples/asr_swbd/local/swbd1_map_words.pl +++ b/examples/asr_swbd/local/swbd1_map_words.pl @@ -1 +1 @@ -../../../speech_tools/kaldi/egs/swbd/s5c/local/swbd1_map_words.pl \ No newline at end of file +../../../espresso/tools/kaldi/egs/swbd/s5c/local/swbd1_map_words.pl \ No newline at end of file diff --git a/examples/asr_swbd/local/swbd1_prepare_dict.sh b/examples/asr_swbd/local/swbd1_prepare_dict.sh index 2b5a643c7..43789ea0c 120000 --- a/examples/asr_swbd/local/swbd1_prepare_dict.sh +++ b/examples/asr_swbd/local/swbd1_prepare_dict.sh @@ -1 +1 @@ -../../../speech_tools/kaldi/egs/swbd/s5c/local/swbd1_prepare_dict.sh \ No newline at end of file +../../../espresso/tools/kaldi/egs/swbd/s5c/local/swbd1_prepare_dict.sh \ No newline at end of file diff --git a/examples/asr_swbd/path.sh b/examples/asr_swbd/path.sh index 19308a610..863f5de3e 100644 --- a/examples/asr_swbd/path.sh +++ b/examples/asr_swbd/path.sh @@ -1,5 +1,5 @@ MAIN_ROOT=$PWD/../.. -KALDI_ROOT=$MAIN_ROOT/speech_tools/kaldi +KALDI_ROOT=$MAIN_ROOT/espresso/tools/kaldi # BEGIN from kaldi path.sh [ -f $KALDI_ROOT/tools/env.sh ] && . 
$KALDI_ROOT/tools/env.sh @@ -10,6 +10,6 @@ export LC_ALL=C # END export PATH=~/anaconda3/bin:$PATH -export PATH=$MAIN_ROOT:$MAIN_ROOT/speech_tools:$PATH -export PYTHONPATH=$MAIN_ROOT:$MAIN_ROOT/speech_tools:$PYTHONPATH +export PATH=$MAIN_ROOT:$MAIN_ROOT/espresso:$MAIN_ROOT/espresso/tools:$PATH +export PYTHONPATH=$MAIN_ROOT:$MAIN_ROOT/espresso:$MAIN_ROOT/espresso/tools:$PYTHONPATH export PYTHONUNBUFFERED=1 diff --git a/examples/asr_swbd/run.sh b/examples/asr_swbd/run.sh index 4b9b0aa7d..43e85a31e 100755 --- a/examples/asr_swbd/run.sh +++ b/examples/asr_swbd/run.sh @@ -101,14 +101,14 @@ if [ $stage -le 1 ]; then # dump features for training if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $train_feat_dir/storage ]; then - utils/create_split_dir.pl \ - /export/b1{4,5,6,7}/$USER/fairseq-data/egs/asr_swbd/dump/$train_set/delta${do_delta}/storage \ - $train_feat_dir/storage + utils/create_split_dir.pl \ + /export/b1{4,5,6,7}/$USER/fairseq-data/egs/asr_swbd/dump/$train_set/delta${do_delta}/storage \ + $train_feat_dir/storage fi if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $valid_feat_dir/storage ]; then - utils/create_split_dir.pl \ - /export/b1{4,5,6,7}/$USER/fairseq-data/egs/asr_swbd/dump/$valid_set/delta${do_delta}/storage \ - $valid_feat_dir/storage + utils/create_split_dir.pl \ + /export/b1{4,5,6,7}/$USER/fairseq-data/egs/asr_swbd/dump/$valid_set/delta${do_delta}/storage \ + $valid_feat_dir/storage fi dump.sh --cmd "$train_cmd" --nj 32 --do_delta $do_delta \ data/$train_set/feats.scp data/$train_set/cmvn.ark exp/dump_feats/train $train_feat_dir @@ -186,7 +186,7 @@ if [ $stage -le 3 ]; then test_paths= && for dataset in $test_set; do test_paths="$test_paths $lmdatadir/$dataset.tokens"; done test_paths=$(echo $test_paths | awk '{$1=$1;print}' | tr ' ' ',') ${decode_cmd} $lmdatadir/logs/preprocess.log \ - python3 ../../preprocess.py --task language_modeling_for_asr \ + python3 ../../preprocess.py --user-dir espresso --task language_modeling_for_asr \ --workers 50 --srcdict $lmdict --only-source \ --trainpref $lmdatadir/train.tokens \ --validpref $lmdatadir/$valid_set.tokens \ @@ -206,7 +206,7 @@ if [ $stage -le 4 ]; then mkdir -p $lmdir/logs log_file=$lmdir/logs/train.log [ -f $lmdir/checkpoint_last.pt ] && log_file="-a $log_file" - CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../train.py $lmdatadir --seed 1 \ + CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../train.py $lmdatadir --seed 1 --user-dir espresso \ --task language_modeling_for_asr --dict $lmdict \ --log-interval 500 --log-format simple \ --num-workers 0 --max-tokens 25600 --max-sentences 1024 \ @@ -227,7 +227,7 @@ if [ $stage -le 5 ]; then test_set_array=($test_set) for i in $(seq 0 $num); do log_file=$lmdir/logs/evaluation_${test_set_array[$i]}.log - python3 ../../eval_lm.py $lmdatadir \ + python3 ../../eval_lm.py $lmdatadir --user-dir espresso \ --task language_modeling_for_asr --dict $lmdict --gen-subset ${gen_set_array[$i]} \ --max-tokens 40960 --max-sentences 1536 --sample-break-mode eos \ --path $lmdir/$lm_checkpoint 2>&1 | tee $log_file @@ -246,7 +246,7 @@ if [ $stage -le 6 ]; then mkdir -p $dir/logs log_file=$dir/logs/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" - CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --seed 1 \ + CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --seed 1 --user-dir espresso \ --log-interval 1500 --log-format simple --print-training-sample-interval 2000 \ --num-workers 0 --max-tokens 26000 --max-sentences 48 --curriculum 2 \ --valid-subset $valid_subset --max-sentences-valid 
64 --ddp-backend no_c10d \ @@ -281,7 +281,7 @@ if [ $stage -le 7 ]; then # only score train_dev with built-in scorer text_opt= && [ "$dataset" == "train_dev" ] && text_opt="--test-text-files data/$dataset/token_text" CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py \ - --max-tokens 24000 --max-sentences 48 --num-shards 1 --shard-id 0 \ + --user-dir espresso --max-tokens 24000 --max-sentences 48 --num-shards 1 --shard-id 0 \ --test-feat-files ${dumpdir}/$dataset/delta${do_delta}/feats.scp $text_opt \ --dict $dict --remove-bpe sentencepiece --non-lang-syms $nlsyms \ --max-source-positions 9999 --max-target-positions 999 \ diff --git a/examples/asr_swbd/steps b/examples/asr_swbd/steps index ec9b528ac..5871b3f9c 120000 --- a/examples/asr_swbd/steps +++ b/examples/asr_swbd/steps @@ -1 +1 @@ -../../speech_tools/kaldi/egs/wsj/s5/steps \ No newline at end of file +../../espresso/tools/kaldi/egs/wsj/s5/steps \ No newline at end of file diff --git a/examples/asr_swbd/utils b/examples/asr_swbd/utils index ea44d93b9..bc8958e91 120000 --- a/examples/asr_swbd/utils +++ b/examples/asr_swbd/utils @@ -1 +1 @@ -../../speech_tools/kaldi/egs/wsj/s5/utils \ No newline at end of file +../../espresso/tools/kaldi/egs/wsj/s5/utils \ No newline at end of file diff --git a/examples/asr_wsj/local/find_transcripts.pl b/examples/asr_wsj/local/find_transcripts.pl index 2455a71d6..0dc047717 120000 --- a/examples/asr_wsj/local/find_transcripts.pl +++ b/examples/asr_wsj/local/find_transcripts.pl @@ -1 +1 @@ -../../../speech_tools/kaldi/egs/wsj/s5/local/find_transcripts.pl \ No newline at end of file +../../../espresso/tools/kaldi/egs/wsj/s5/local/find_transcripts.pl \ No newline at end of file diff --git a/examples/asr_wsj/local/flist2scp.pl b/examples/asr_wsj/local/flist2scp.pl index e7c6f9da4..44508e89f 120000 --- a/examples/asr_wsj/local/flist2scp.pl +++ b/examples/asr_wsj/local/flist2scp.pl @@ -1 +1 @@ -../../../speech_tools/kaldi/egs/wsj/s5/local/flist2scp.pl \ No newline at end of file +../../../espresso/tools/kaldi/egs/wsj/s5/local/flist2scp.pl \ No newline at end of file diff --git a/examples/asr_wsj/local/ndx2flist.pl b/examples/asr_wsj/local/ndx2flist.pl index 2c868304e..2e330c631 120000 --- a/examples/asr_wsj/local/ndx2flist.pl +++ b/examples/asr_wsj/local/ndx2flist.pl @@ -1 +1 @@ -../../../speech_tools/kaldi/egs/wsj/s5/local/ndx2flist.pl \ No newline at end of file +../../../espresso/tools/kaldi/egs/wsj/s5/local/ndx2flist.pl \ No newline at end of file diff --git a/examples/asr_wsj/local/normalize_transcript.pl b/examples/asr_wsj/local/normalize_transcript.pl index 975e24acf..c8891c186 120000 --- a/examples/asr_wsj/local/normalize_transcript.pl +++ b/examples/asr_wsj/local/normalize_transcript.pl @@ -1 +1 @@ -../../../speech_tools/kaldi/egs/wsj/s5/local/normalize_transcript.pl \ No newline at end of file +../../../espresso/tools/kaldi/egs/wsj/s5/local/normalize_transcript.pl \ No newline at end of file diff --git a/examples/asr_wsj/local/wsj_data_prep.sh b/examples/asr_wsj/local/wsj_data_prep.sh index f909e21b5..ce5ccf527 120000 --- a/examples/asr_wsj/local/wsj_data_prep.sh +++ b/examples/asr_wsj/local/wsj_data_prep.sh @@ -1 +1 @@ -../../../speech_tools/kaldi/egs/wsj/s5/local/wsj_data_prep.sh \ No newline at end of file +../../../espresso/tools/kaldi/egs/wsj/s5/local/wsj_data_prep.sh \ No newline at end of file diff --git a/examples/asr_wsj/path.sh b/examples/asr_wsj/path.sh index 19308a610..863f5de3e 100644 --- a/examples/asr_wsj/path.sh +++ 
b/examples/asr_wsj/path.sh @@ -1,5 +1,5 @@ MAIN_ROOT=$PWD/../.. -KALDI_ROOT=$MAIN_ROOT/speech_tools/kaldi +KALDI_ROOT=$MAIN_ROOT/espresso/tools/kaldi # BEGIN from kaldi path.sh [ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh @@ -10,6 +10,6 @@ export LC_ALL=C # END export PATH=~/anaconda3/bin:$PATH -export PATH=$MAIN_ROOT:$MAIN_ROOT/speech_tools:$PATH -export PYTHONPATH=$MAIN_ROOT:$MAIN_ROOT/speech_tools:$PYTHONPATH +export PATH=$MAIN_ROOT:$MAIN_ROOT/espresso:$MAIN_ROOT/espresso/tools:$PATH +export PYTHONPATH=$MAIN_ROOT:$MAIN_ROOT/espresso:$MAIN_ROOT/espresso/tools:$PYTHONPATH export PYTHONUNBUFFERED=1 diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index 4c6ad2cea..33338ba17 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -161,7 +161,7 @@ if [ ${stage} -le 3 ]; then echo "$0: binarizing char text..." mkdir -p $lmdatadir/logs ${decode_cmd} $lmdatadir/logs/preprocess.log \ - python3 ../../preprocess.py --task language_modeling_for_asr \ + python3 ../../preprocess.py --user-dir espresso --task language_modeling_for_asr \ --workers 30 --srcdict $lmdict --only-source \ --trainpref $lmdatadir/train.tokens \ --validpref $lmdatadir/$valid_set.tokens \ @@ -171,7 +171,7 @@ if [ ${stage} -le 3 ]; then echo "$0: binarizing word text..." mkdir -p $wordlmdatadir/logs ${decode_cmd} $wordlmdatadir/logs/preprocess.log \ - python3 ../../preprocess.py --task language_modeling_for_asr \ + python3 ../../preprocess.py --user-dir espresso --task language_modeling_for_asr \ --workers 30 --srcdict $wordlmdict --only-source \ --trainpref $wordlmdatadir/train \ --validpref $wordlmdatadir/$valid_set \ @@ -192,7 +192,7 @@ if [ ${stage} -le 4 ] && ! $use_wordlm; then mkdir -p $lmdir/logs log_file=$lmdir/logs/train.log [ -f $lmdir/checkpoint_last.pt ] && log_file="-a $log_file" - CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../train.py $lmdatadir --seed 1 \ + CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../train.py $lmdatadir --seed 1 --user-dir espresso \ --task language_modeling_for_asr --dict $lmdict \ --log-interval 2000 --log-format simple \ --num-workers 0 --max-tokens 25600 --max-sentences 128 \ @@ -209,7 +209,7 @@ if [ ${stage} -le 5 ] && ! 
$use_wordlm; then echo "Stage 5: char LM Evaluation" for gen_subset in valid test; do log_file=$lmdir/logs/evaluation_$gen_subset.log - python3 ../../eval_lm.py $lmdatadir --cpu \ + python3 ../../eval_lm.py $lmdatadir --user-dir espresso --cpu \ --task language_modeling_for_asr --dict $lmdict --gen-subset $gen_subset \ --max-tokens 192000 --max-sentences 256 --sample-break-mode eos \ --path $lmdir/$lm_checkpoint 2>&1 | tee $log_file @@ -222,7 +222,7 @@ if [ ${stage} -le 6 ] && $use_wordlm; then mkdir -p $wordlmdir/logs log_file=$wordlmdir/logs/train.log [ -f $wordlmdir/checkpoint_last.pt ] && log_file="-a $log_file" - CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../train.py $wordlmdatadir --seed 1 \ + CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../train.py $wordlmdatadir --seed 1 --user-dir espresso \ --task language_modeling_for_asr --dict $wordlmdict \ --log-interval 2000 --log-format simple \ --num-workers 0 --max-tokens 6400 --max-sentences 256 \ @@ -240,7 +240,7 @@ if [ ${stage} -le 7 ] && $use_wordlm; then echo "Stage 7: word LM Evaluation" for gen_subset in valid test; do log_file=$wordlmdir/logs/evaluation_$gen_subset.log - python3 ../../eval_lm.py $wordlmdatadir --cpu \ + python3 ../../eval_lm.py $wordlmdatadir --user-dir espresso --cpu \ --task language_modeling_for_asr --dict $wordlmdict --gen-subset $gen_subset \ --max-tokens 12800 --max-sentences 512 --sample-break-mode eos \ --path $wordlmdir/$lm_checkpoint 2>&1 | tee $log_file @@ -264,7 +264,7 @@ if [ ${stage} -le 8 ]; then mkdir -p $dir/logs log_file=$dir/logs/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" - CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --seed 1 \ + CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --seed 1 --user-dir espresso \ --log-interval 400 --log-format simple --print-training-sample-interval 1000 \ --num-workers 0 --max-tokens 24000 --max-sentences 32 --curriculum 2 \ --valid-subset $valid_subset --max-sentences-valid 64 --ddp-backend no_c10d \ @@ -307,7 +307,7 @@ if [ ${stage} -le 9 ]; then fi text=data/$dataset/token_text CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py \ - --max-tokens 20000 --max-sentences 32 --num-shards 1 --shard-id 0 \ + --user-dir espresso --max-tokens 20000 --max-sentences 32 --num-shards 1 --shard-id 0 \ --test-feat-files $feat --test-text-files $text \ --dict $dict --non-lang-syms $nlsyms \ --max-source-positions 9999 --max-target-positions 999 \ diff --git a/examples/asr_wsj/steps b/examples/asr_wsj/steps index ec9b528ac..5871b3f9c 120000 --- a/examples/asr_wsj/steps +++ b/examples/asr_wsj/steps @@ -1 +1 @@ -../../speech_tools/kaldi/egs/wsj/s5/steps \ No newline at end of file +../../espresso/tools/kaldi/egs/wsj/s5/steps \ No newline at end of file diff --git a/examples/asr_wsj/utils b/examples/asr_wsj/utils index ea44d93b9..bc8958e91 120000 --- a/examples/asr_wsj/utils +++ b/examples/asr_wsj/utils @@ -1 +1 @@ -../../speech_tools/kaldi/egs/wsj/s5/utils \ No newline at end of file +../../espresso/tools/kaldi/egs/wsj/s5/utils \ No newline at end of file diff --git a/fairseq/data/__init__.py b/fairseq/data/__init__.py index ae7c0f1c9..9b3081395 100644 --- a/fairseq/data/__init__.py +++ b/fairseq/data/__init__.py @@ -56,9 +56,6 @@ from .multilingual.sampled_multi_dataset import SampledMultiDataset from .multilingual.sampled_multi_epoch_dataset import SampledMultiEpochDataset from .fasta_dataset import FastaDataset, EncodedFastaDataset -from .token_dictionary import TokenDictionary -from .scp_dataset import 
ScpDataset, ScpCachedDataset, ScpInMemoryDataset, TokenTextDataset -from .speech_dataset import SpeechDataset from .iterators import ( CountingIterator, @@ -124,10 +121,4 @@ "TransformEosLangPairDataset", "TruncateDataset", "TruncatedDictionary", - 'TokenDictionary', - 'ScpDataset', - 'ScpCachedDataset', - 'ScpInMemoryDataset', - 'TokenTextDataset', - 'SpeechDataset', ] diff --git a/tests/espresso/__init__.py b/tests/espresso/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_speech_dataset.py b/tests/espresso/test_speech_dataset.py similarity index 97% rename from tests/test_speech_dataset.py rename to tests/espresso/test_speech_dataset.py index 33f4a8529..eb48acc4c 100644 --- a/tests/test_speech_dataset.py +++ b/tests/espresso/test_speech_dataset.py @@ -10,9 +10,13 @@ import torch -from fairseq.data import ( - SpeechDataset, TokenDictionary, TokenTextDataset, ScpCachedDataset, - ScpInMemoryDataset) +from espresso.data import ( + ScpCachedDataset, + ScpInMemoryDataset, + SpeechDataset, + TokenDictionary, + TokenTextDataset, +) try: import kaldi_io diff --git a/tests/test_speech_utils.py b/tests/espresso/test_speech_utils.py similarity index 99% rename from tests/test_speech_utils.py rename to tests/espresso/test_speech_utils.py index d3aa0945b..9a3c0b2a9 100644 --- a/tests/test_speech_utils.py +++ b/tests/espresso/test_speech_utils.py @@ -10,9 +10,9 @@ import torch -from fairseq.data import TokenDictionary +from espresso.data import TokenDictionary -import speech_tools.utils as utils +import espresso.tools.utils as utils class TestSpeechUtils(unittest.TestCase): From 9dd8992b335835130f42ccc7abdfbfa421dae0ea Mon Sep 17 00:00:00 2001 From: freewym Date: Tue, 17 Dec 2019 01:24:52 -0500 Subject: [PATCH 051/119] remove coverage term for beam search decoding as it has been superceded by eos thresholding --- espresso/speech_recognize.py | 6 +----- espresso/speech_train.py | 12 +++++------- espresso/tasks/speech_recognition.py | 1 - examples/asr_librispeech/run.sh | 8 ++++---- examples/asr_swbd/run.sh | 8 ++++---- examples/asr_wsj/run.sh | 10 +++++----- 6 files changed, 19 insertions(+), 26 deletions(-) diff --git a/espresso/speech_recognize.py b/espresso/speech_recognize.py index e198990c1..5ca59aa6b 100755 --- a/espresso/speech_recognize.py +++ b/espresso/speech_recognize.py @@ -75,7 +75,7 @@ def main(args): for model in models: model.make_generation_fast_( beamable_mm_beam_size=None if args.no_beamable_mm else args.beam, - need_attn=args.print_alignment or args.coverage_weight > 0., + need_attn=args.print_alignment, ) if args.fp16: model.half() @@ -229,10 +229,6 @@ def print_options_meaning_changes(args): def cli_main(): parser = options.get_generation_parser(default_task='speech_recognition') - parser.add_argument('--coverage-weight', default=0.0, type=float, metavar='W', - help='coverage weight in log-prob space, mostly to ' - 'reduce deletion errors while using the pretrained ' - 'external LM for decoding') parser.add_argument('--eos-factor', default=None, type=float, metavar='F', help='only consider emitting EOS if its score is no less ' 'than the specified factor of the best candidate score') diff --git a/espresso/speech_train.py b/espresso/speech_train.py index 54d158bb8..6b46df933 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -19,7 +19,6 @@ from fairseq.data import iterators from fairseq.trainer import Trainer from fairseq.meters import AverageMeter, StopwatchMeter -from fairseq.utils import import_user_module def main(args, 
init_distributed=False): @@ -71,7 +70,7 @@ def main(args, init_distributed=False): # corresponding train iterator extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer) - if callable(getattr(trainer.criterion, 'set_train_tgt_dataset', None)): + if hasattr(trainer.criterion, 'set_train_tgt_dataset'): trainer.criterion.set_train_tgt_dataset(task.dataset(args.train_subset).tgt) # Train until the learning rate gets too small @@ -111,7 +110,6 @@ def main(args, init_distributed=False): def train(args, trainer, task, epoch_itr): """Train the model for one epoch.""" - # Update parameters every N batches update_freq = args.update_freq[epoch_itr.epoch - 1] \ if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1] @@ -129,10 +127,10 @@ def train(args, trainer, task, epoch_itr): extra_meters = collections.defaultdict(lambda: AverageMeter()) valid_subsets = args.valid_subset.split(',') max_update = args.max_update or math.inf - if callable(getattr(trainer.criterion, 'set_epoch', None)): + if hasattr(trainer.criterion, 'set_epoch'): trainer.criterion.set_epoch(epoch_itr.epoch) for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch): - if callable(getattr(trainer.criterion, 'set_num_updates', None)): + if hasattr(trainer.criterion, 'set_num_updates'): trainer.criterion.set_num_updates(trainer.get_num_updates()) log_output = trainer.train_step(samples) @@ -228,7 +226,7 @@ def validate(args, trainer, task, epoch_itr, subsets): trainer.get_model().max_positions(), ), ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, - required_batch_size_multiple=8, + required_batch_size_multiple=args.required_batch_size_multiple, seed=args.seed, num_shards=args.distributed_world_size, shard_id=args.distributed_rank, @@ -247,7 +245,7 @@ def validate(args, trainer, task, epoch_itr, subsets): meter.reset() extra_meters = collections.defaultdict(lambda: AverageMeter()) - if callable(getattr(trainer.criterion, 'set_valid_tgt_dataset', None)): + if hasattr(trainer.criterion, 'set_valid_tgt_dataset'): trainer.criterion.set_valid_tgt_dataset(task.dataset(subset).tgt) for sample in progress: diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index 8e7fb64e3..16f31f426 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -247,7 +247,6 @@ def build_generator(self, args): diverse_beam_strength=getattr(args, 'diverse_beam_strength', 0.5), match_source_len=getattr(args, 'match_source_len', False), no_repeat_ngram_size=getattr(args, 'no_repeat_ngram_size', 0), - coverage_weight=getattr(args, 'coverage_weight', 0.0), eos_factor=getattr(args, 'eos_factor', None), ) diff --git a/examples/asr_librispeech/run.sh b/examples/asr_librispeech/run.sh index 1082617ed..cd854cd37 100755 --- a/examples/asr_librispeech/run.sh +++ b/examples/asr_librispeech/run.sh @@ -205,7 +205,7 @@ if [ ${stage} -le 7 ]; then mkdir -p $dir/logs log_file=$dir/logs/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" - CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --task speech_recognition --seed 1 --user-dir espresso \ --log-interval 4000 --log-format simple --print-training-sample-interval 2000 \ --num-workers 0 --max-tokens 26000 --max-sentences 24 --curriculum 1 \ --valid-subset $valid_subset --max-sentences-valid 48 --ddp-backend no_c10d \ @@ -230,15 +230,15 @@ if [ ${stage} -le 8 ]; then decode_affix= if $lm_shallow_fusion; 
then path="$path:$lmdir/$lm_checkpoint" - opts="$opts --lm-weight 0.47 --coverage-weight 0.0 --eos-factor 1.5" + opts="$opts --lm-weight 0.47 --eos-factor 1.5" decode_affix=shallow_fusion fi for dataset in $test_set; do feat=${dumpdir}/$dataset/delta${do_delta}/feats.scp text=data/$dataset/token_text CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py \ - --user-dir espresso --max-tokens 15000 --max-sentences 24 --num-shards 1 --shard-id 0 \ - --test-feat-files $feat --test-text-files $text \ + --task speech_recognition --user-dir espresso --max-tokens 15000 --max-sentences 24 \ + --num-shards 1 --shard-id 0 --test-feat-files $feat --test-text-files $text \ --dict $dict --remove-bpe sentencepiece \ --max-source-positions 9999 --max-target-positions 999 \ --path $path --beam 60 --max-len-a 0.08 --max-len-b 0 --lenpen 1.0 \ diff --git a/examples/asr_swbd/run.sh b/examples/asr_swbd/run.sh index 43e85a31e..634952d0e 100755 --- a/examples/asr_swbd/run.sh +++ b/examples/asr_swbd/run.sh @@ -246,7 +246,7 @@ if [ $stage -le 6 ]; then mkdir -p $dir/logs log_file=$dir/logs/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" - CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --task speech_recognition --seed 1 --user-dir espresso \ --log-interval 1500 --log-format simple --print-training-sample-interval 2000 \ --num-workers 0 --max-tokens 26000 --max-sentences 48 --curriculum 2 \ --valid-subset $valid_subset --max-sentences-valid 64 --ddp-backend no_c10d \ @@ -272,7 +272,7 @@ if [ $stage -le 7 ]; then decode_affix= if $lm_shallow_fusion; then path="$path:$lmdir/$lm_checkpoint" - opts="$opts --lm-weight 0.25 --coverage-weight 0.0" + opts="$opts --lm-weight 0.25" decode_affix=shallow_fusion fi [ -f local/wer_output_filter ] && opts="$opts --wer-output-filter local/wer_output_filter" @@ -281,8 +281,8 @@ if [ $stage -le 7 ]; then # only score train_dev with built-in scorer text_opt= && [ "$dataset" == "train_dev" ] && text_opt="--test-text-files data/$dataset/token_text" CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py \ - --user-dir espresso --max-tokens 24000 --max-sentences 48 --num-shards 1 --shard-id 0 \ - --test-feat-files ${dumpdir}/$dataset/delta${do_delta}/feats.scp $text_opt \ + --task speech_recognition --user-dir espresso --max-tokens 24000 --max-sentences 48 \ + --num-shards 1 --shard-id 0 --test-feat-files ${dumpdir}/$dataset/delta${do_delta}/feats.scp $text_opt \ --dict $dict --remove-bpe sentencepiece --non-lang-syms $nlsyms \ --max-source-positions 9999 --max-target-positions 999 \ --path $path --beam 35 --max-len-a 0.1 --max-len-b 0 --lenpen 1.0 \ diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index 33338ba17..363d7afcd 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -264,7 +264,7 @@ if [ ${stage} -le 8 ]; then mkdir -p $dir/logs log_file=$dir/logs/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" - CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --task speech_recognition --seed 1 --user-dir espresso \ --log-interval 400 --log-format simple --print-training-sample-interval 1000 \ --num-workers 0 --max-tokens 24000 --max-sentences 32 --curriculum 2 \ --valid-subset $valid_subset --max-sentences-valid 64 --ddp-backend no_c10d \ @@ -290,11 +290,11 @@ if [ ${stage} -le 9 ]; then if 
$lm_shallow_fusion; then if ! $use_wordlm; then path="$path:$lmdir/$lm_checkpoint" - opts="$opts --lm-weight 0.7 --coverage-weight 0.01" + opts="$opts --lm-weight 0.7 --coverage-weight 0.01 --eos-factor 1.5" decode_affix=shallow_fusion else path="$path:$wordlmdir/$lm_checkpoint" - opts="$opts --word-dict $wordlmdict --lm-weight 0.9 --oov-penalty 1e-7 --coverage-weight 0.0 --eos-factor 1.5" + opts="$opts --word-dict $wordlmdict --lm-weight 0.9 --oov-penalty 1e-7 --eos-factor 1.5" decode_affix=shallow_fusion_wordlm fi fi @@ -307,8 +307,8 @@ if [ ${stage} -le 9 ]; then fi text=data/$dataset/token_text CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py \ - --user-dir espresso --max-tokens 20000 --max-sentences 32 --num-shards 1 --shard-id 0 \ - --test-feat-files $feat --test-text-files $text \ + --task speech_recognition --user-dir espresso --max-tokens 20000 --max-sentences 32 \ + --num-shards 1 --shard-id 0 --test-feat-files $feat --test-text-files $text \ --dict $dict --non-lang-syms $nlsyms \ --max-source-positions 9999 --max-target-positions 999 \ --path $path --beam 50 --max-len-a 0.2 --max-len-b 0 --lenpen 1.0 \ From 69dbd15ab49e1dcd743a796f280d88f64757cb61 Mon Sep 17 00:00:00 2001 From: freewym Date: Wed, 18 Dec 2019 01:26:59 -0500 Subject: [PATCH 052/119] fix bugs causing build failure; a bunch of lint changes; rename TokenDictionary->AsrDictionary, TokenTextDataset->AsrTextDataset --- espresso/__init__.py | 1 - espresso/criterions/cross_entropy_with_wer.py | 84 ++++---- .../label_smoothed_cross_entropy_with_wer.py | 102 ++++++---- espresso/data/__init__.py | 10 +- ...{token_dictionary.py => asr_dictionary.py} | 10 +- .../{scp_dataset.py => scp_text_dataset.py} | 60 +++--- espresso/data/speech_dataset.py | 17 +- espresso/models/external_language_model.py | 185 ++++++++++-------- espresso/models/speech_fconv.py | 51 ++--- espresso/models/speech_lstm.py | 106 +++++----- espresso/models/speech_transformer.py | 56 +++--- .../tensorized_lookahead_language_model.py | 31 +-- espresso/modules/speech_attention.py | 16 +- .../lr_scheduler/reduce_lr_on_plateau_v2.py | 6 +- espresso/speech_recognize.py | 31 +-- espresso/speech_train.py | 2 +- espresso/tasks/language_modeling_for_asr.py | 16 +- espresso/tasks/speech_recognition.py | 43 ++-- espresso/tools/compute_wer.py | 18 +- espresso/tools/lexical_prefix_tree.py | 64 ++++++ espresso/tools/tensorized_prefix_tree.py | 11 +- espresso/tools/text2token.py | 7 +- espresso/tools/text2vocabulary.py | 5 +- espresso/tools/utils.py | 95 +++------ espresso/tools/wer.py | 39 ++-- examples/asr_librispeech/run.sh | 4 +- examples/asr_swbd/local/prepare_ctm.py | 14 +- examples/asr_swbd/run.sh | 4 +- examples/asr_wsj/run.sh | 4 +- setup.py | 1 + tests/espresso/test_speech_dataset.py | 51 +++-- tests/espresso/test_speech_utils.py | 93 +++++---- 32 files changed, 685 insertions(+), 552 deletions(-) rename espresso/data/{token_dictionary.py => asr_dictionary.py} (94%) rename espresso/data/{scp_dataset.py => scp_text_dataset.py} (82%) create mode 100644 espresso/tools/lexical_prefix_tree.py diff --git a/espresso/__init__.py b/espresso/__init__.py index 98cb3449e..666272541 100644 --- a/espresso/__init__.py +++ b/espresso/__init__.py @@ -3,7 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
- import espresso.criterions # noqa import espresso.models # noqa import espresso.modules # noqa diff --git a/espresso/criterions/cross_entropy_with_wer.py b/espresso/criterions/cross_entropy_with_wer.py index ad3ebd14e..91a4d7b45 100644 --- a/espresso/criterions/cross_entropy_with_wer.py +++ b/espresso/criterions/cross_entropy_with_wer.py @@ -12,7 +12,7 @@ from fairseq.models import FairseqIncrementalDecoder from fairseq.options import eval_str_list -from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.criterions import register_criterion from fairseq.criterions.cross_entropy import CrossEntropyCriterion from espresso.tools import wer @@ -25,8 +25,7 @@ def __init__(self, args, task): super().__init__(args, task) dict = task.target_dictionary - self.scorer = wer.Scorer(dict, - wer_output_filter=task.args.wer_output_filter) + self.scorer = wer.Scorer(dict, wer_output_filter=task.args.wer_output_filter) self.train_tgt_dataset = None self.valid_tgt_dataset = None self.num_updates = -1 @@ -62,13 +61,15 @@ def forward(self, model, sample, reduce=True): """ dict = self.scorer.dict if model.training: - if (len(self.args.scheduled_sampling_probs) > 1 or \ - self.args.scheduled_sampling_probs[0] < 1.0) and \ - self.epoch >= self.args.start_scheduled_sampling_epoch: + if ( + (len(self.args.scheduled_sampling_probs) > 1 or + self.args.scheduled_sampling_probs[0] < 1.0) and + self.epoch >= self.args.start_scheduled_sampling_epoch + ): # scheduled sampling ss_prob = self.args.scheduled_sampling_probs[ min(self.epoch - self.args.start_scheduled_sampling_epoch, - len(self.args.scheduled_sampling_probs) - 1) + len(self.args.scheduled_sampling_probs) - 1) ] assert isinstance(model.decoder, FairseqIncrementalDecoder) incremental_states = {} @@ -80,17 +81,22 @@ def forward(self, model, sample, reduce=True): target = sample['target'] tokens = sample['net_input']['prev_output_tokens'] lprobs = [] + pred = None for step in range(target.size(1)): if step > 0: - sampling_mask = torch.rand([target.size(0), 1], - device=target.device).lt(ss_prob) - feed_tokens = torch.where(sampling_mask, - tokens[:, step:step + 1], pred) + sampling_mask = torch.rand( + [target.size(0), 1], + device=target.device, + ).lt(ss_prob) + feed_tokens = torch.where( + sampling_mask, tokens[:, step:step + 1], pred, + ) else: feed_tokens = tokens[:, step:step + 1] - log_probs, _ = self._decode(feed_tokens, - model, encoder_out, incremental_states) - pred = log_probs.argmax(-1,keepdim=True) + log_probs, _ = self._decode( + feed_tokens, model, encoder_out, incremental_states, + ) + pred = log_probs.argmax(-1, keepdim=True) lprobs.append(log_probs) lprobs = torch.stack(lprobs, dim=1) else: @@ -116,7 +122,7 @@ def forward(self, model, sample, reduce=True): attn = [] if getattr(model.decoder, 'need_attn', False) else None dummy_log_probs = encoder_out['encoder_out'][0].new_full( [target.size(0), len(dict)], -np.log(len(dict))) - for step in range(maxlen + 1): # one extra step for EOS marker + for step in range(maxlen + 1): # one extra step for EOS marker is_eos = tokens[:, step].eq(dict.eos()) # if all predictions are finished (i.e., ended with eos), # pad lprobs to target length with dummy log probs, @@ -126,10 +132,11 @@ def forward(self, model, sample, reduce=True): lprobs.append(dummy_log_probs) tokens = tokens[:, :step + 1] break - log_probs, attn_scores = self._decode(tokens[:, :step + 1], - model, encoder_out, incremental_states) + log_probs, attn_scores = self._decode( + tokens[:, :step + 1], model, 
encoder_out, incremental_states, + ) tokens[:, step + 1] = log_probs.argmax(-1) - if step > 0: # deal with finished predictions + if step > 0: # deal with finished predictions # make log_probs uniform if the previous output token is EOS # and add consecutive EOS to the end of prediction log_probs[is_eos, :] = -np.log(log_probs.size(1)) @@ -144,35 +151,43 @@ def forward(self, model, sample, reduce=True): # bsz x (maxlen + 1) x (length of encoder_out) attn = torch.stack(attn, dim=1) # word error stats code starts - if not model.training or (self.num_updates // self.args.print_interval > - (self.num_updates - 1) // self.args.print_interval): + if ( + not model.training or + ( + self.num_updates // self.args.print_interval > + (self.num_updates - 1) // self.args.print_interval + ) + ): pred = lprobs.argmax(-1).cpu() if model.training else \ - tokens[:, 1:].data.cpu() # bsz x len + tokens[:, 1:].data.cpu() # bsz x len - if not model.training: # validation step, compute WER stats with scorer + if not model.training: # validation step, compute WER stats with scorer assert pred.size(0) == target.size(0) self.scorer.reset() for i in range(target.size(0)): utt_id = sample['utt_id'][i] id = sample['id'].data[i].item() - #ref_tokens = dict.string(target.data[i]) + # ref_tokens = dict.string(target.data[i]) # if it is a dummy batch (e.g., a "padding" batch in a sharded # dataset), id might exceeds the dataset size; in that case we # just skip it if id < len(self.valid_tgt_dataset): ref_tokens = self.valid_tgt_dataset.get_original_tokens(id) pred_tokens = dict.string(pred.data[i]) - self.scorer.add_evaluation(utt_id, ref_tokens, - pred_tokens, bpe_symbol=self.args.remove_bpe) - else: # print a randomly sampled result every print_interval updates + self.scorer.add_evaluation( + utt_id, ref_tokens, pred_tokens, + bpe_symbol=self.args.remove_bpe, + ) + else: # print a randomly sampled result every print_interval updates assert pred.size() == target.size() with data_utils.numpy_seed(self.num_updates): i = np.random.randint(0, len(sample['id'])) id = sample['id'].data[i].item() length = utils.strip_pad(target.data[i], self.padding_idx).size(0) - #ref_one = dict.tokens_to_sentence(dict.string(target.data[i])) - ref_one = self.train_tgt_dataset.get_original_text(id, dict, - bpe_symbol=self.args.remove_bpe) + # ref_one = dict.tokens_to_sentence(dict.string(target.data[i])) + ref_one = self.train_tgt_dataset.get_original_text( + id, dict, bpe_symbol=self.args.remove_bpe, + ) pred_one = dict.tokens_to_sentence( dict.string(pred.data[i][:length]), bpe_symbol=self.args.remove_bpe, @@ -195,7 +210,7 @@ def forward(self, model, sample, reduce=True): 'nsentences': sample['target'].size(0), 'sample_size': sample_size, } - if not model.training: # do not compute word error in training mode + if not model.training: # do not compute word error in training mode logging_output['word_error'] = self.scorer.tot_word_error() logging_output['word_count'] = self.scorer.tot_word_count() logging_output['char_error'] = self.scorer.tot_char_error() @@ -210,17 +225,18 @@ def aggregate_logging_outputs(logging_outputs): word_count = sum(log.get('word_count', 0) for log in logging_outputs) char_error = sum(log.get('char_error', 0) for log in logging_outputs) char_count = sum(log.get('char_count', 0) for log in logging_outputs) - if word_count > 0: # model.training == False + if word_count > 0: # model.training == False agg_output['word_error'] = word_error agg_output['word_count'] = word_count - if char_count > 0: # model.training == 
False + if char_count > 0: # model.training == False agg_output['char_error'] = char_error agg_output['char_count'] = char_count return agg_output def _decode(self, tokens, model, encoder_out, incremental_states): - decoder_out = list(model.decoder(tokens, encoder_out, - incremental_state=incremental_states)) + decoder_out = list(model.forward_decoder( + tokens, encoder_out=encoder_out, incremental_state=incremental_states, + )) decoder_out[0] = decoder_out[0][:, -1:, :] attn = decoder_out[1] if type(attn) is dict: diff --git a/espresso/criterions/label_smoothed_cross_entropy_with_wer.py b/espresso/criterions/label_smoothed_cross_entropy_with_wer.py index dca7c1e40..1fdc44785 100644 --- a/espresso/criterions/label_smoothed_cross_entropy_with_wer.py +++ b/espresso/criterions/label_smoothed_cross_entropy_with_wer.py @@ -11,14 +11,16 @@ from fairseq.models import FairseqIncrementalDecoder from fairseq.options import eval_str_list -from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.criterions import register_criterion from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion from espresso.tools import wer -def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True, - smoothing_type='uniform', prob_mask=None, unigram_tensor=None): +def label_smoothed_nll_loss( + lprobs, target, epsilon, ignore_index=None, reduce=True, + smoothing_type='uniform', prob_mask=None, unigram_tensor=None, +): if target.dim() == lprobs.dim() - 1: target = target.unsqueeze(-1) nll_loss = -lprobs.gather(dim=-1, index=target) @@ -54,8 +56,7 @@ def __init__(self, args, task): super().__init__(args, task) dict = task.target_dictionary - self.scorer = wer.Scorer(dict, - wer_output_filter=task.args.wer_output_filter) + self.scorer = wer.Scorer(dict, wer_output_filter=task.args.wer_output_filter) self.train_tgt_dataset = None self.valid_tgt_dataset = None self.num_updates = -1 @@ -105,13 +106,15 @@ def forward(self, model, sample, reduce=True): """ dict = self.scorer.dict if model.training: - if (len(self.args.scheduled_sampling_probs) > 1 or \ - self.args.scheduled_sampling_probs[0] < 1.0) and \ - self.epoch >= self.args.start_scheduled_sampling_epoch: + if ( + (len(self.args.scheduled_sampling_probs) > 1 or + self.args.scheduled_sampling_probs[0] < 1.0) and + self.epoch >= self.args.start_scheduled_sampling_epoch + ): # scheduled sampling ss_prob = self.args.scheduled_sampling_probs[ min(self.epoch - self.args.start_scheduled_sampling_epoch, - len(self.args.scheduled_sampling_probs) - 1) + len(self.args.scheduled_sampling_probs) - 1) ] assert isinstance(model.decoder, FairseqIncrementalDecoder) incremental_states = {} @@ -123,16 +126,21 @@ def forward(self, model, sample, reduce=True): target = sample['target'] tokens = sample['net_input']['prev_output_tokens'] lprobs = [] + pred = None for step in range(target.size(1)): if step > 0: - sampling_mask = torch.rand([target.size(0), 1], - device=target.device).lt(ss_prob) - feed_tokens = torch.where(sampling_mask, - tokens[:, step:step + 1], pred) + sampling_mask = torch.rand( + [target.size(0), 1], + device=target.device, + ).lt(ss_prob) + feed_tokens = torch.where( + sampling_mask, tokens[:, step:step + 1], pred, + ) else: feed_tokens = tokens[:, step:step + 1] - log_probs, _ = self._decode(feed_tokens, - model, encoder_out, incremental_states) + log_probs, _ = self._decode( + feed_tokens, model, encoder_out, incremental_states, + ) pred = log_probs.argmax(-1, keepdim=True) 
lprobs.append(log_probs) lprobs = torch.stack(lprobs, dim=1) @@ -159,7 +167,7 @@ def forward(self, model, sample, reduce=True): attn = [] if getattr(model.decoder, 'need_attn', False) else None dummy_log_probs = encoder_out['encoder_out'][0].new_full( [target.size(0), len(dict)], -np.log(len(dict))) - for step in range(maxlen + 1): # one extra step for EOS marker + for step in range(maxlen + 1): # one extra step for EOS marker is_eos = tokens[:, step].eq(dict.eos()) # if all predictions are finished (i.e., ended with eos), # pad lprobs to target length with dummy log probs, @@ -169,10 +177,11 @@ def forward(self, model, sample, reduce=True): lprobs.append(dummy_log_probs) tokens = tokens[:, :step + 1] break - log_probs, attn_scores = self._decode(tokens[:, :step + 1], - model, encoder_out, incremental_states) + log_probs, attn_scores = self._decode( + tokens[:, :step + 1], model, encoder_out, incremental_states, + ) tokens[:, step + 1] = log_probs.argmax(-1) - if step > 0: # deal with finished predictions + if step > 0: # deal with finished predictions # make log_probs uniform if the previous output token is EOS # and add consecutive EOS to the end of prediction log_probs[is_eos, :] = -np.log(log_probs.size(1)) @@ -187,35 +196,43 @@ def forward(self, model, sample, reduce=True): # bsz x (maxlen + 1) x (length of encoder_out) attn = torch.stack(attn, dim=1) # word error stats code starts - if not model.training or (self.num_updates // self.args.print_interval > - (self.num_updates - 1) // self.args.print_interval): + if ( + not model.training or + ( + self.num_updates // self.args.print_interval > + (self.num_updates - 1) // self.args.print_interval + ) + ): pred = lprobs.argmax(-1).cpu() if model.training else \ - tokens[:, 1:].data.cpu() # bsz x len + tokens[:, 1:].data.cpu() # bsz x len - if not model.training: # validation step, compute WER stats with scorer + if not model.training: # validation step, compute WER stats with scorer assert pred.size(0) == target.size(0) self.scorer.reset() for i in range(target.size(0)): utt_id = sample['utt_id'][i] id = sample['id'].data[i].item() - #ref_tokens = dict.string(target.data[i]) + # ref_tokens = dict.string(target.data[i]) # if it is a dummy batch (e.g., a "padding" batch in a sharded # dataset), id might exceeds the dataset size; in that case we # just skip it if id < len(self.valid_tgt_dataset): ref_tokens = self.valid_tgt_dataset.get_original_tokens(id) pred_tokens = dict.string(pred.data[i]) - self.scorer.add_evaluation(utt_id, ref_tokens, - pred_tokens, bpe_symbol=self.args.remove_bpe) - else: # print a randomly sampled result every print_interval updates + self.scorer.add_evaluation( + utt_id, ref_tokens, pred_tokens, + bpe_symbol=self.args.remove_bpe, + ) + else: # print a randomly sampled result every print_interval updates assert pred.size() == target.size() with data_utils.numpy_seed(self.num_updates): i = np.random.randint(0, len(sample['id'])) id = sample['id'].data[i].item() length = utils.strip_pad(target.data[i], self.padding_idx).size(0) - #ref_one = dict.tokens_to_sentence(dict.string(target.data[i])) - ref_one = self.train_tgt_dataset.get_original_text(id, dict, - bpe_symbol=self.args.remove_bpe) + # ref_one = dict.tokens_to_sentence(dict.string(target.data[i])) + ref_one = self.train_tgt_dataset.get_original_text( + id, dict, bpe_symbol=self.args.remove_bpe, + ) pred_one = dict.tokens_to_sentence( dict.string(pred.data[i][:length]), bpe_symbol=self.args.remove_bpe, @@ -227,22 +244,22 @@ def forward(self, model, sample, 
reduce=True): if self.args.smoothing_type == 'temporal': # see https://arxiv.org/pdf/1612.02695.pdf # prob_mask.dtype=int for deterministic behavior of Tensor.scatter_add_() - prob_mask = torch.zeros_like(lprobs, dtype=torch.int) # bsz x tgtlen x vocab_size - idx_tensor = target.new_full(target.size(), self.padding_idx).unsqueeze(-1) # bsz x tgtlen x 1 + prob_mask = torch.zeros_like(lprobs, dtype=torch.int) # bsz x tgtlen x vocab_size + idx_tensor = target.new_full(target.size(), self.padding_idx).unsqueeze(-1) # bsz x tgtlen x 1 # hard-code the remaining probabilty mass distributed symmetrically # over neighbors at distance ±1 and ±2 with a 5 : 2 ratio - idx_tensor[:, 2:, 0] = target[:, :-2] # two neighbors to the left + idx_tensor[:, 2:, 0] = target[:, :-2] # two neighbors to the left prob_mask.scatter_add_(-1, idx_tensor, prob_mask.new([2]).expand_as(idx_tensor)) idx_tensor.fill_(self.padding_idx)[:, 1:, 0] = target[:, :-1] prob_mask.scatter_add_(-1, idx_tensor, prob_mask.new([5]).expand_as(idx_tensor)) - idx_tensor.fill_(self.padding_idx)[:, :-2, 0] = target[:, 2:] # two neighbors to the right + idx_tensor.fill_(self.padding_idx)[:, :-2, 0] = target[:, 2:] # two neighbors to the right prob_mask.scatter_add_(-1, idx_tensor, prob_mask.new([2]).expand_as(idx_tensor)) idx_tensor.fill_(self.padding_idx)[:, :-1, 0] = target[:, 1:] prob_mask.scatter_add_(-1, idx_tensor, prob_mask.new([5]).expand_as(idx_tensor)) - prob_mask[:, :, self.padding_idx] = 0 # clear cumulative count on - prob_mask = prob_mask.float() # convert to float + prob_mask[:, :, self.padding_idx] = 0 # clear cumulative count on + prob_mask = prob_mask.float() # convert to float sum_prob = prob_mask.sum(-1, keepdim=True) - sum_prob[sum_prob.squeeze(-1).eq(0.)] = 1. # to deal with the "division by 0" problem + sum_prob[sum_prob.squeeze(-1).eq(0.)] = 1. 
# to deal with the "division by 0" problem prob_mask = prob_mask.div_(sum_prob).view(-1, prob_mask.size(-1)) lprobs = lprobs.view(-1, lprobs.size(-1)) @@ -260,7 +277,7 @@ def forward(self, model, sample, reduce=True): 'nsentences': sample['target'].size(0), 'sample_size': sample_size, } - if not model.training: # do not compute word error in training mode + if not model.training: # do not compute word error in training mode logging_output['word_error'] = self.scorer.tot_word_error() logging_output['word_count'] = self.scorer.tot_word_count() logging_output['char_error'] = self.scorer.tot_char_error() @@ -275,17 +292,18 @@ def aggregate_logging_outputs(logging_outputs): word_count = sum(log.get('word_count', 0) for log in logging_outputs) char_error = sum(log.get('char_error', 0) for log in logging_outputs) char_count = sum(log.get('char_count', 0) for log in logging_outputs) - if word_count > 0: # model.training == False + if word_count > 0: # model.training == False agg_output['word_error'] = word_error agg_output['word_count'] = word_count - if char_count > 0: # model.training == False + if char_count > 0: # model.training == False agg_output['char_error'] = char_error agg_output['char_count'] = char_count return agg_output def _decode(self, tokens, model, encoder_out, incremental_states): - decoder_out = list(model.decoder(tokens, encoder_out, - incremental_state=incremental_states)) + decoder_out = list(model.forward_decoder( + tokens, encoder_out=encoder_out, incremental_state=incremental_states, + )) decoder_out[0] = decoder_out[0][:, -1:, :] attn = decoder_out[1] if type(attn) is dict: diff --git a/espresso/data/__init__.py b/espresso/data/__init__.py index b8bfbc584..d2fcf423c 100644 --- a/espresso/data/__init__.py +++ b/espresso/data/__init__.py @@ -3,15 +3,15 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .token_dictionary import TokenDictionary -from .scp_dataset import ScpDataset, ScpCachedDataset, ScpInMemoryDataset, TokenTextDataset +from .asr_dictionary import AsrDictionary +from .scp_text_dataset import AsrTextDataset, ScpCachedDataset, ScpDataset, ScpInMemoryDataset from .speech_dataset import SpeechDataset __all__ = [ - 'ScpDataset', + 'AsrDictionary', + 'AsrTextDataset', 'ScpCachedDataset', + 'ScpDataset', 'ScpInMemoryDataset', - 'TokenDictionary', - 'TokenTextDataset', 'SpeechDataset', ] diff --git a/espresso/data/token_dictionary.py b/espresso/data/asr_dictionary.py similarity index 94% rename from espresso/data/token_dictionary.py rename to espresso/data/asr_dictionary.py index fc9e6fbef..570c94add 100644 --- a/espresso/data/token_dictionary.py +++ b/espresso/data/asr_dictionary.py @@ -4,12 +4,13 @@ # LICENSE file in the root directory of this source tree. 
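# Sketch of the 'temporal' smoothing used in label_smoothed_cross_entropy_with_wer.py
# above (cf. arXiv:1612.02695): the smoothing mass is spread over the labels of
# neighboring time steps, with distance-1 and distance-2 neighbors weighted 5:2.
# A standalone version with hypothetical names (the criterion above builds the
# mask with an int dtype for deterministic scatter_add_; this sketch uses float
# directly) might look like:
import torch

def temporal_smoothing_dist(target, vocab_size, pad_idx):
    # target: LongTensor (bsz, tgt_len); returns (bsz, tgt_len, vocab_size) whose
    # rows sum to 1 where at least one valid neighbor exists, else stay all-zero
    mask = target.new_zeros(target.size(0), target.size(1), vocab_size, dtype=torch.float)
    idx = target.new_full(target.size(), pad_idx).unsqueeze(-1)   # (bsz, tgt_len, 1)
    for dist, weight in ((1, 5.0), (2, 2.0)):
        idx.fill_(pad_idx)[:, dist:, 0] = target[:, :-dist]       # neighbors to the left
        mask.scatter_add_(-1, idx, mask.new_full(idx.size(), weight))
        idx.fill_(pad_idx)[:, :-dist, 0] = target[:, dist:]       # neighbors to the right
        mask.scatter_add_(-1, idx, mask.new_full(idx.size(), weight))
    mask[:, :, pad_idx] = 0.                                       # drop mass collected on padding
    denom = mask.sum(-1, keepdim=True).clamp(min=1.)               # guard against division by zero
    return mask / denom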
import torch + from fairseq.data import Dictionary, data_utils from fairseq.file_io import PathManager from fairseq.tokenizer import tokenize_line -class TokenDictionary(Dictionary): +class AsrDictionary(Dictionary): """A mapping from symbols to consecutive integers""" def __init__( @@ -107,7 +108,7 @@ def load(cls, f, f_non_lang_syms=None): for sym in non_lang_syms: assert d.index(sym) != d.unk(), \ - "{} in {} is not in the dictionary".format(sym, f_non_lang_syms) + "{} in {} is not in the dictionary".format(sym, f_non_lang_syms) d.non_lang_syms = non_lang_syms return d @@ -118,8 +119,9 @@ def dummy_sentence(self, length): t[-1] = self.eos() return t - def tokens_to_sentence(self, line, line_tokenizer=tokenize_line, - use_unk_sym=True, bpe_symbol=None): + def tokens_to_sentence( + self, line, line_tokenizer=tokenize_line, use_unk_sym=True, bpe_symbol=None, + ): if bpe_symbol is not None: return data_utils.process_bpe_symbol(line, bpe_symbol) # use_unk_sym=False when we want to restore original transcripts from diff --git a/espresso/data/scp_dataset.py b/espresso/data/scp_text_dataset.py similarity index 82% rename from espresso/data/scp_dataset.py rename to espresso/data/scp_text_dataset.py index 02a07a3b1..c976e78e0 100644 --- a/espresso/data/scp_dataset.py +++ b/espresso/data/scp_text_dataset.py @@ -27,23 +27,22 @@ def read_scp(self, path): scp_entries = [line.strip().split(None, 1) for line in f] self.utt_ids = [entry[0] for entry in scp_entries] self.extended_filenames = [entry[1] for entry in scp_entries] - self.size = len(scp_entries) # number of utterances - self.sizes=[] # length of each utterance + self.size = len(scp_entries) # number of utterances + self.sizes = [] # length of each utterance for filename in self.extended_filenames: try: feat = kaldi_io.read_mat(filename) - except: + except Exception: print('Failed to read feature matrix {}.'.format(filename)) raise assert feat is not None and isinstance(feat, np.ndarray) self.sizes.append(feat.shape[0]) self.sizes = np.array(self.sizes, dtype=np.int32) - self.feat_dim = feat.shape[1] # feature dimension + self.feat_dim = feat.shape[1] # feature dimension assert len(self.utt_ids) == len(self.extended_filenames) and \ len(self.utt_ids) == len(self.sizes) - def check_index(self, i): if i < 0 or i >= self.size: raise IndexError('index out of range') @@ -80,14 +79,13 @@ def __init__(self, path, ordered_prefetch=False, cache_size=4096): super().__init__(path) self.cache = None self.cache_index = {} - self.cache_size = cache_size # in terms of number of examples + self.cache_size = cache_size # in terms of number of examples self.start_search_for_next_pos_start = 0 self.ordered_indices = list(range(self.size)) - self.ordered_prefetch = ordered_prefetch # set to True ONLY if examples - # are queried in the same order - # as self.ordered_indices, and - # doing this will speed up - # search of the queried index. + # set to True ONLY if examples are queried in the same order as + # self.ordered_indices, and doing this will speed up search of the + # queried index + self.ordered_prefetch = ordered_prefetch @property def supports_prefetch(self): @@ -111,15 +109,19 @@ def __getitem__(self, i): len(self.ordered_indices), \ 'Position for next cache starting beyond the end of ordered_indices.' try: - pos_start = self.ordered_indices.index(i, - self.start_pos_for_next_cache) + pos_start = self.ordered_indices.index( + i, self.start_pos_for_next_cache, + ) except ValueError: - print('index {} not found in self.ordered_indices. 
Set ' - 'self.ordered_prefetch to False, and/or call self.prefetch() ' - 'with the full list of indices, and then try again.'.format(i)) + print( + 'index {} not found in self.ordered_indices. Set ' + 'self.ordered_prefetch to False, and/or call self.prefetch() ' + 'with the full list of indices, and then try again.'.format(i) + ) raise - pos_end = min(pos_start + self.cache_size, - len(self.ordered_indices)) + pos_end = min( + pos_start + self.cache_size, len(self.ordered_indices), + ) self.start_pos_for_next_cache = pos_end \ if self.ordered_prefetch else 0 total_size = 0 @@ -146,11 +148,13 @@ class ScpInMemoryDataset(ScpDataset): def __init__(self, path): super().__init__(path) self.read_data() - + def read_data(self): self.data_offsets = np.append([0], np.cumsum(self.sizes)[:-1]) - self.buffer = np.empty((sum(self.sizes), self.feat_dim), - dtype=self.dtype) + self.buffer = np.empty( + (sum(self.sizes), self.feat_dim), + dtype=self.dtype, + ) for i in range(len(self.data_offsets)): ptx = self.data_offsets[i] dst = self.buffer[ptx: ptx + self.sizes[i]] @@ -167,7 +171,7 @@ def __getitem__(self, i): return torch.from_numpy(a).float() -class TokenTextDataset(torch.utils.data.Dataset): +class AsrTextDataset(torch.utils.data.Dataset): """Takes a text file as input and binarizes it in memory at instantiation. Original lines are also kept in memory. Each line of the text file is in the format of 'utt_id tokenized_text'.""" @@ -188,12 +192,13 @@ def read_text(self, path, dictionary): utt_id, tokens = line.strip().split(None, 1) self.utt_ids.append(utt_id) self.tokens_list.append(tokens) - tensor = dictionary.encode_line(tokens, - add_if_not_exist=False, append_eos=self.append_eos).long() + tensor = dictionary.encode_line( + tokens, add_if_not_exist=False, append_eos=self.append_eos, + ).long() self.tensor_list.append(tensor) self.sizes.append(len(self.tensor_list[-1])) - self.size = len(self.utt_ids) # number of utterances + self.size = len(self.utt_ids) # number of utterances self.sizes = np.array(self.sizes, dtype=np.int32) assert len(self.utt_ids) == len(self.tokens_list) and \ @@ -226,8 +231,9 @@ def get_original_tokens(self, i): def get_original_text(self, i, dictionary, bpe_symbol=None): self.check_index(i) - return dictionary.tokens_to_sentence(self.tokens_list[i], - use_unk_sym=False, bpe_symbol=bpe_symbol) + return dictionary.tokens_to_sentence( + self.tokens_list[i], use_unk_sym=False, bpe_symbol=bpe_symbol, + ) def __len__(self): return self.size diff --git a/espresso/data/speech_dataset.py b/espresso/data/speech_dataset.py index 10b5cab34..1b6b1de9c 100644 --- a/espresso/data/speech_dataset.py +++ b/espresso/data/speech_dataset.py @@ -24,10 +24,10 @@ def merge(key, left_pad, move_eos_to_beginning=False): [s[key] for s in samples], 0.0, left_pad, ) elif key == 'target': - return data_utils.collate_tokens( - [s[key] for s in samples], - pad_idx, eos_idx, left_pad, move_eos_to_beginning, - ) + return data_utils.collate_tokens( + [s[key] for s in samples], + pad_idx, eos_idx, left_pad, move_eos_to_beginning, + ) else: raise ValueError('Invalid key.') @@ -66,8 +66,7 @@ def merge(key, left_pad, move_eos_to_beginning=False): 'nsentences': len(samples), 'ntokens': ntokens, 'net_input': { - 'src_tokens': src_frames, # key name kept due to - # FairseqModel::forward(...,src_tokens,...) 
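# Sketch of what the 'source' branch of merge() above has to do: pad a list of
# variable-length feature matrices into one (bsz, max_len, feat_dim) batch.
# This is a hypothetical stand-in, shown only to make the collater concrete;
# Espresso's own collate helper may differ in details such as padding side.
import torch

def collate_frames_sketch(frames, pad_value=0.0, left_pad=False):
    # frames: list of FloatTensors, each of shape (length_i, feat_dim)
    max_len = max(f.size(0) for f in frames)
    out = frames[0].new_full((len(frames), max_len, frames[0].size(1)), pad_value)
    for i, f in enumerate(frames):
        if left_pad:
            out[i, max_len - f.size(0):] = f
        else:
            out[i, :f.size(0)] = f
    return out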
+ 'src_tokens': src_frames, 'src_lengths': src_lengths, }, 'target': target, @@ -129,8 +128,9 @@ def _match_src_tgt(self): if self.src.utt_ids == self.tgt.utt_ids: return tgt_utt_ids_set = set(self.tgt.utt_ids) - src_indices = [i for i, id in enumerate(self.src.utt_ids) \ - if id in tgt_utt_ids_set] + src_indices = [ + i for i, id in enumerate(self.src.utt_ids) if id in tgt_utt_ids_set + ] self.src.filter_and_reorder(src_indices) self.src_sizes = np.array(self.src.sizes) try: @@ -220,4 +220,3 @@ def supports_prefetch(self): def prefetch(self, indices): """Only prefetch src.""" self.src.prefetch(indices) - diff --git a/espresso/models/external_language_model.py b/espresso/models/external_language_model.py index 1805f3d55..00777d0d9 100644 --- a/espresso/models/external_language_model.py +++ b/espresso/models/external_language_model.py @@ -5,14 +5,13 @@ import math import torch -import torch.nn as nn -import torch.nn.functional as F -from fairseq import options, utils +from fairseq import utils from fairseq.models import FairseqIncrementalDecoder, FairseqLanguageModel -from espresso.data import TokenDictionary -from espresso.tools.utils import tokenize, lexical_prefix_tree +from espresso.data import AsrDictionary +from espresso.tools.lexical_prefix_tree import lexical_prefix_tree +from espresso.tools.utils import tokenize def _clone_cached_state(cached_state): @@ -41,8 +40,9 @@ class LookAheadWordLanguageModel(RawOutExternalLanguageModelBase): wrapper for :class:`_LookAheadWordLanguageModelDecoder`. """ def __init__(self, wordlm, subword_dict, oov_penalty=1e-4, open_vocab=True): - decoder = _LookAheadWordLanguageModelDecoder(wordlm, subword_dict, - oov_penalty, open_vocab) + decoder = _LookAheadWordLanguageModelDecoder( + wordlm, subword_dict, oov_penalty, open_vocab, + ) super().__init__(decoder) @@ -62,22 +62,22 @@ def __init__(self, wordlm, subword_dict, oov_penalty=1e-4, open_vocab=True): 'The wrapped decoder should implement masked_copy_incremental_state()' self.oov_penalty = oov_penalty self.open_vocab = open_vocab - self.zero = 1e-10 # a sufficiently small value to avoid the log(0) issue + self.zero = 1e-10 # a sufficiently small value to avoid the log(0) issue word_dict = self.lm_decoder.dictionary - assert isinstance(word_dict, TokenDictionary) + assert isinstance(word_dict, AsrDictionary) self.word_pad_idx = word_dict.pad() self.word_eos_idx = word_dict.eos() self.word_unk_idx = word_dict.unk() - assert isinstance(subword_dict, TokenDictionary) + assert isinstance(subword_dict, AsrDictionary) self.subword_space_idx = subword_dict.space() self.subword_pad_idx = subword_dict.pad() self.subword_eos_idx = subword_dict.eos() self.subword_vocab_size = len(subword_dict) - tokenizer = lambda x: tokenize( - x, non_lang_syms=subword_dict.non_lang_syms).split(' ') + def tokenizer(x): + return tokenize(x, non_lang_syms=subword_dict.non_lang_syms).split(' ') self.lexroot = lexical_prefix_tree(word_dict, subword_dict, tokenizer) def max_out_degree(node): @@ -103,7 +103,7 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): cached_state = utils.get_incremental_state( self.lm_decoder, incremental_state, 'cached_state') - if cached_state is None: # it is the first time step + if cached_state is None: # it is the first time step assert (prev_output_tokens == self.subword_eos_idx).all(), \ 'expecting the input to the first time step to be ' w = prev_output_tokens.new_full([bsz, 1], self.word_eos_idx) @@ -118,17 +118,19 @@ def forward(self, prev_output_tokens, 
encoder_out=None, incremental_state=None): nodes = utils.get_incremental_state(self, incremental_state, 'nodes') assert len(nodes) == bsz w = prev_output_tokens.new([ - node.word_idx if node is not None and node.word_idx >= 0 else \ + node.word_idx if node is not None and node.word_idx >= 0 else self.word_unk_idx for node in nodes - ]).unsqueeze(-1) # B x 1 + ]).unsqueeze(-1) # B x 1 old_cached_state = _clone_cached_state(cached_state) # recompute cumsum_probs from inter-word transition probabilities # only for those whose prev_output_token is lm_probs = self.lm_decoder.get_normalized_probs( self.lm_decoder(w, incremental_state=incremental_state), - log_probs=False, sample=None) # B x 1 x V - self.lm_decoder.masked_copy_incremental_state(incremental_state, - old_cached_state, batch_space_mask) # restore those not masked + log_probs=False, sample=None, + ) # B x 1 x V + self.lm_decoder.masked_copy_incremental_state( + incremental_state, old_cached_state, batch_space_mask, + ) # restore those not masked cumsum_probs[batch_space_mask] = \ torch.cumsum(lm_probs, dim=-1)[batch_space_mask] tokens_list = prev_output_tokens.squeeze(-1).tolist() @@ -139,7 +141,7 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): elif nodes[i] is not None and tokens_list[i] in nodes[i].children: # intra-word transition: go to child nodes[i] = nodes[i].children[tokens_list[i]] - else: # no path in the tree + else: # no path in the tree nodes[i] = None utils.set_incremental_state( @@ -150,7 +152,7 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): if self.open_vocab: # set out_probs to oov_penalty * P(|h) (case 3 in Eqn. 15) out_probs = self.oov_penalty * ( - cumsum_probs[:, :, self.word_unk_idx] - \ + cumsum_probs[:, :, self.word_unk_idx] - cumsum_probs[:, :, self.word_unk_idx - 1] ).unsqueeze(-1).repeat(1, 1, self.subword_vocab_size) # set the probability of emitting to 0 if prev_output_tokens @@ -169,11 +171,12 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): out_probs[batch_node_none_mask] = 1. else: # set out_probs to 0 - out_probs = cumsum_probs.new_full([bsz, 1, self.subword_vocab_size], - self.zero) + out_probs = cumsum_probs.new_full( + [bsz, 1, self.subword_vocab_size], self.zero, + ) # compute parent probabilities for those whose node is not None - sum_probs = cumsum_probs.new_full([bsz, 1], 1.) # default for root node + sum_probs = cumsum_probs.new_full([bsz, 1], 1.) # default for root node left_ranges, right_ranges, batch_node_not_root_mask = [], [], [] for node in nodes: if node is not None and node.word_set is not None: @@ -188,17 +191,20 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): right_ranges = prev_output_tokens.new(right_ranges).unsqueeze(-1) batch_node_not_root_mask = batch_space_mask.new(batch_node_not_root_mask) sum_probs[batch_node_not_root_mask] = ( - cumsum_probs[batch_node_not_root_mask].gather(-1, right_ranges) - \ + cumsum_probs[batch_node_not_root_mask].gather(-1, right_ranges) - cumsum_probs[batch_node_not_root_mask].gather(-1, left_ranges) ).squeeze(-1) # compute transition probabilities to child nodes (case 2 in Eqn. 
15) - subword_idx = [[self.subword_pad_idx] * self.max_num_children \ - for _ in range(bsz)] - left_ranges = [[self.word_pad_idx] * self.max_num_children \ - for _ in range(bsz)] - right_ranges = [[self.word_pad_idx] * self.max_num_children \ - for _ in range(bsz)] + subword_idx = [ + [self.subword_pad_idx] * self.max_num_children for _ in range(bsz) + ] + left_ranges = [ + [self.word_pad_idx] * self.max_num_children for _ in range(bsz) + ] + right_ranges = [ + [self.word_pad_idx] * self.max_num_children for _ in range(bsz) + ] for i in range(bsz): node = nodes[i] if node is not None and len(node.children) > 0: @@ -210,8 +216,10 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): subword_idx = prev_output_tokens.new(subword_idx).unsqueeze(1) left_ranges = prev_output_tokens.new(left_ranges).unsqueeze(1) right_ranges = prev_output_tokens.new(right_ranges).unsqueeze(1) - cumsum_probs_children = (cumsum_probs.gather(-1, right_ranges) - \ - cumsum_probs.gather(-1, left_ranges)) / sum_probs.unsqueeze(-1) + cumsum_probs_children = ( + cumsum_probs.gather(-1, right_ranges) - + cumsum_probs.gather(-1, left_ranges) + ) / sum_probs.unsqueeze(-1) cumsum_probs_children[sum_probs.squeeze(-1) < self.zero, :, :] = self.zero out_probs.scatter_(-1, subword_idx, cumsum_probs_children) out_probs[:, :, self.subword_pad_idx] = self.zero @@ -225,15 +233,16 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): else: batch_node_word_end_mask.append(False) if len(word_idx) > 0: - word_idx = prev_output_tokens.new(word_idx).unsqueeze(-1) # b x 1 x 1 + word_idx = prev_output_tokens.new(word_idx).unsqueeze(-1) # b x 1 x 1 batch_node_word_end_mask = batch_space_mask.new(batch_node_word_end_mask) word_probs = torch.where( sum_probs[batch_node_word_end_mask] < self.zero, cumsum_probs.new([self.zero]), - (cumsum_probs[batch_node_word_end_mask].gather(-1, word_idx) - \ - cumsum_probs[batch_node_word_end_mask].gather(-1, word_idx - 1) + ( + cumsum_probs[batch_node_word_end_mask].gather(-1, word_idx) - + cumsum_probs[batch_node_word_end_mask].gather(-1, word_idx - 1) ).squeeze(-1).div_(sum_probs[batch_node_word_end_mask]), - ) # b x 1 + ) # b x 1 out_probs[batch_node_word_end_mask, :, self.subword_space_idx] = word_probs # take log of probs and clip it from below to avoid log(0) @@ -255,15 +264,17 @@ def reorder_incremental_state(self, incremental_state, new_order): self, incremental_state, 'cumsum_probs') if cumsum_probs is not None: new_cumsum_probs = cumsum_probs.index_select(0, new_order) - utils.set_incremental_state(self, incremental_state, 'cumsum_probs', - new_cumsum_probs) + utils.set_incremental_state( + self, incremental_state, 'cumsum_probs', new_cumsum_probs, + ) nodes = utils.get_incremental_state(self, incremental_state, 'nodes') if nodes is not None: new_order_list = new_order.tolist() new_nodes = [nodes[i] for i in new_order_list] - utils.set_incremental_state(self, incremental_state, 'nodes', - new_nodes) + utils.set_incremental_state( + self, incremental_state, 'nodes', new_nodes, + ) def get_normalized_probs(self, net_output, log_probs, sample=None): """Get normalized probabilities (or log probs) from a net's output.""" @@ -278,10 +289,12 @@ class MultiLevelLanguageModel(RawOutExternalLanguageModelBase): """A :class:`fairseq.external_language_model.RawOutExternalLanguageModelBase` wrapper for :class:`_MultiLevelLanguageModel`. 
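# Sketch of the look-ahead computation appearing above, simplified to a single
# hypothesis; the half-open [left, right) index convention is an assumption made
# for illustration. If the word LM's vocabulary is ordered so that every
# lexical-prefix-tree node owns a contiguous range of word ids, the probability
# mass of a node is a difference of two cumulative sums, and the transition
# probability to a child node (case 2 of Eqn. 15) is the child's mass
# renormalized by the parent's mass:
import torch

def node_mass(cumsum_probs, left, right):
    # cumsum_probs: 1-D cumulative word probabilities (size = word vocab)
    # returns the mass assigned to words with ids in [left, right)
    assert 0 <= left < right <= cumsum_probs.numel()
    lower = cumsum_probs[left - 1] if left > 0 else cumsum_probs.new_zeros(())
    return cumsum_probs[right - 1] - lower

def child_transition_prob(cumsum_probs, parent_range, child_range, zero=1e-10):
    parent = node_mass(cumsum_probs, *parent_range).clamp(min=zero)
    return node_mass(cumsum_probs, *child_range) / parent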
""" - def __init__(self, wordlm, subwordlm, subwordlm_weight=0.8, oov_penalty=1.0, - open_vocab=True): - decoder = _MultiLevelLanguageModel(wordlm, subwordlm, subwordlm_weight, - oov_penalty, open_vocab) + def __init__( + self, wordlm, subwordlm, subwordlm_weight=0.8, oov_penalty=1.0, open_vocab=True, + ): + decoder = _MultiLevelLanguageModel( + wordlm, subwordlm, subwordlm_weight, oov_penalty, open_vocab, + ) super().__init__(decoder) @@ -292,8 +305,9 @@ class _MultiLevelLanguageModel(FairseqIncrementalDecoder): original algorithm a little bit to adapt it to the case where each tokenized sentence ends with before . """ - def __init__(self, wordlm, subwordlm, subwordlm_weight=0.8, oov_penalty=1.0, - open_vocab=True): + def __init__( + self, wordlm, subwordlm, subwordlm_weight=0.8, oov_penalty=1.0, open_vocab=True, + ): super().__init__(wordlm.decoder.dictionary) assert isinstance(wordlm, FairseqLanguageModel) @@ -309,18 +323,18 @@ def __init__(self, wordlm, subwordlm, subwordlm_weight=0.8, oov_penalty=1.0, self.logzero = -10.0 word_dict = self.wordlm_decoder.dictionary - assert isinstance(word_dict, TokenDictionary) + assert isinstance(word_dict, AsrDictionary) self.word_eos_idx = word_dict.eos() self.word_unk_idx = word_dict.unk() subword_dict = self.subwordlm_decoder.dictionary - assert isinstance(subword_dict, TokenDictionary) + assert isinstance(subword_dict, AsrDictionary) self.subword_space_idx = subword_dict.space() self.subword_eos_idx = subword_dict.eos() self.subword_vocab_size = len(subword_dict) - tokenizer = lambda x: tokenize( - x, non_lang_syms=subword_dict.non_lang_syms).split(' ') + def tokenizer(x): + return tokenize(x, non_lang_syms=subword_dict.non_lang_syms).split(' ') self.lexroot = lexical_prefix_tree(word_dict, subword_dict, tokenizer) @torch.no_grad() @@ -338,7 +352,7 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): subwordlm_cached_state = utils.get_incremental_state( self.subwordlm_decoder, incremental_state, 'cached_state') - if wordlm_cached_state is None: # it is the first time step + if wordlm_cached_state is None: # it is the first time step assert subwordlm_cached_state is None assert (prev_output_tokens == self.subword_eos_idx).all(), \ 'expecting the input to the first time step to be ' @@ -347,28 +361,31 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): self.wordlm_decoder(w, incremental_state=incremental_state), log_probs=True, sample=None, - ) # B x 1 x V + ) # B x 1 x V sw = prev_output_tokens.new_full([bsz, 1], self.subword_eos_idx) out_logprobs = self.subwordlm_decoder.get_normalized_probs( self.subwordlm_decoder(sw, incremental_state=incremental_state), log_probs=True, sample=None, - ) * self.subwordlm_weight # B x 1 x V + ) * self.subwordlm_weight # B x 1 x V subword_cumlogprobs = out_logprobs.new_zeros(sw.size()) nodes = [self.lexroot] * bsz else: - wordlm_logprobs = utils.get_incremental_state(self, - incremental_state, 'wordlm_logprobs') - out_logprobs = utils.get_incremental_state(self, incremental_state, - 'out_logprobs') - subword_cumlogprobs = utils.get_incremental_state(self, - incremental_state, 'subword_cumlogprobs') + wordlm_logprobs = utils.get_incremental_state( + self, incremental_state, 'wordlm_logprobs', + ) + out_logprobs = utils.get_incremental_state( + self, incremental_state, 'out_logprobs', + ) + subword_cumlogprobs = utils.get_incremental_state( + self, incremental_state, 'subword_cumlogprobs', + ) nodes = utils.get_incremental_state(self, incremental_state, 
'nodes') assert len(nodes) == bsz w = prev_output_tokens.new([ - node.word_idx if node is not None and node.word_idx >= 0 else \ + node.word_idx if node is not None and node.word_idx >= 0 else self.word_unk_idx for node in nodes - ]).unsqueeze(-1) # B x 1 + ]).unsqueeze(-1) # B x 1 old_wordlm_cached_state = _clone_cached_state(wordlm_cached_state) # recompute wordlm_logprobs from inter-word transition probabilities @@ -378,8 +395,9 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): log_probs=True, sample=None, )[batch_space_mask] - self.wordlm_decoder.masked_copy_incremental_state(incremental_state, - old_wordlm_cached_state, batch_space_mask) # restore those not masked + self.wordlm_decoder.masked_copy_incremental_state( + incremental_state, old_wordlm_cached_state, batch_space_mask, + ) # restore those not masked tokens_list = prev_output_tokens.squeeze(-1).tolist() token_idx, batch_is_child_mask = [], [] @@ -393,12 +411,12 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): nodes[i] = nodes[i].children[tokens_list[i]] token_idx.append([tokens_list[i]]) batch_is_child_mask.append(True) - else: # no path in the tree + else: # no path in the tree nodes[i] = None if self.open_vocab: token_idx.append([tokens_list[i]]) batch_is_child_mask.append(False) - token_idx = prev_output_tokens.new(token_idx).unsqueeze(-1) # b x 1 x 1 + token_idx = prev_output_tokens.new(token_idx).unsqueeze(-1) # b x 1 x 1 if self.open_vocab: subword_cumlogprobs[batch_space_mask] = 0. assert batch_not_space_mask.sum().item() == len(token_idx) @@ -420,21 +438,25 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): batch_oov_mask = batch_not_space_mask & ~batch_is_child_mask out_logprobs[batch_oov_mask] = self.logzero - utils.set_incremental_state(self, incremental_state, 'wordlm_logprobs', - wordlm_logprobs) - utils.set_incremental_state(self, incremental_state, 'subword_cumlogprobs', - subword_cumlogprobs) + utils.set_incremental_state( + self, incremental_state, 'wordlm_logprobs', wordlm_logprobs, + ) + utils.set_incremental_state( + self, incremental_state, 'subword_cumlogprobs', subword_cumlogprobs, + ) utils.set_incremental_state(self, incremental_state, 'nodes', nodes) # apply word-level probabilies for emitting w = prev_output_tokens.new([ - node.word_idx if node is not None and node.word_idx >= 0 else \ + node.word_idx if node is not None and node.word_idx >= 0 else self.word_unk_idx for node in nodes - ]).unsqueeze(-1) # B x 1 - word_logprobs = wordlm_logprobs.gather(-1, w.unsqueeze(-1)).squeeze(-1) # B x 1 + ]).unsqueeze(-1) # B x 1 + word_logprobs = wordlm_logprobs.gather(-1, w.unsqueeze(-1)).squeeze(-1) # B x 1 batch_word_end_mask = w.ne(self.word_unk_idx) - word_logprobs += torch.where(batch_word_end_mask, - -subword_cumlogprobs, word_logprobs.new([self.log_oov_penalty])) + word_logprobs += torch.where( + batch_word_end_mask, + -subword_cumlogprobs, word_logprobs.new([self.log_oov_penalty]), + ) out_logprobs[:, :, self.subword_space_idx] = word_logprobs # set the probability of emitting to 0 if prev_output_tokens is @@ -449,8 +471,9 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): out_logprobs[batch_space_mask, :, self.subword_eos_idx] += \ wordlm_logprobs[batch_space_mask, :, self.word_eos_idx] - utils.set_incremental_state(self, incremental_state, 'out_logprobs', - out_logprobs) + utils.set_incremental_state( + self, incremental_state, 'out_logprobs', out_logprobs, + ) # note 
that here we return log-probs rather than logits, and the second # element is None, which is usually a tensor of attention weights in @@ -464,15 +487,17 @@ def reorder_incremental_state(self, incremental_state, new_order): state = utils.get_incremental_state(self, incremental_state, state_name) if state is not None: new_state = state.index_select(0, new_order) - utils.set_incremental_state(self, incremental_state, state_name, - new_state) + utils.set_incremental_state( + self, incremental_state, state_name, new_state, + ) nodes = utils.get_incremental_state(self, incremental_state, 'nodes') if nodes is not None: new_order_list = new_order.tolist() new_nodes = [nodes[i] for i in new_order_list] - utils.set_incremental_state(self, incremental_state, 'nodes', - new_nodes) + utils.set_incremental_state( + self, incremental_state, 'nodes', new_nodes, + ) def get_normalized_probs(self, net_output, log_probs, sample=None): """Get normalized probabilities (or log probs) from a net's output.""" diff --git a/espresso/models/speech_fconv.py b/espresso/models/speech_fconv.py index 36e682adf..b3bd59891 100644 --- a/espresso/models/speech_fconv.py +++ b/espresso/models/speech_fconv.py @@ -90,20 +90,17 @@ def eval_str_nested_list_or_tuple(x, type=int): else: try: return type(x) - except: - raise ValueError - - out_channels = eval_str_nested_list_or_tuple(args.encoder_conv_channels, - type=int) - kernel_sizes = eval_str_nested_list_or_tuple( - args.encoder_conv_kernel_sizes, type=int) - strides = eval_str_nested_list_or_tuple(args.encoder_conv_strides, - type=int) - print('| input feature dimension: {}, channels: {}'.format(task.feat_dim, - task.feat_in_channels)) + except TypeError: + raise TypeError + + out_channels = eval_str_nested_list_or_tuple(args.encoder_conv_channels, type=int) + kernel_sizes = eval_str_nested_list_or_tuple(args.encoder_conv_kernel_sizes, type=int) + strides = eval_str_nested_list_or_tuple(args.encoder_conv_strides, type=int) + print('| input feature dimension: {}, channels: {}'.format(task.feat_dim, task.feat_in_channels)) assert task.feat_dim % task.feat_in_channels == 0 - conv_layers = ConvBNReLU(out_channels, kernel_sizes, strides, - in_channels=task.feat_in_channels) if not out_channels is None else None + conv_layers = ConvBNReLU( + out_channels, kernel_sizes, strides, in_channels=task.feat_in_channels, + ) if out_channels is not None else None fconv_encoder_input_size = task.feat_dim // task.feat_in_channels if conv_layers is not None: @@ -116,7 +113,7 @@ def eval_str_nested_list_or_tuple(x, type=int): s = stride fconv_encoder_input_size = (fconv_encoder_input_size + s - 1) // s fconv_encoder_input_size *= out_channels[-1] - + encoder = SpeechFConvEncoder( conv_layers_before=conv_layers, input_size=fconv_encoder_input_size, @@ -169,7 +166,7 @@ def __init__( self.conv_layers_before = conv_layers_before self.fc0 = Linear(input_size, embed_dim, dropout=dropout) \ if input_size != embed_dim else None - + convolutions = extend_conv_spec(convolutions) in_channels = convolutions[0][0] self.fc1 = Linear(embed_dim, in_channels, dropout=dropout) @@ -221,8 +218,7 @@ def forward(self, src_tokens, src_lengths): padding elements of shape `(batch, src_len)` """ if self.conv_layers_before is not None: - x, src_lengths, encoder_padding_mask = self.conv_layers_before(src_tokens, - src_lengths) + x, src_lengths, encoder_padding_mask = self.conv_layers_before(src_tokens, src_lengths) else: x, encoder_padding_mask = src_tokens, \ ~speech_utils.sequence_mask(src_lengths, src_tokens.size(1)) @@ 
-306,8 +302,10 @@ def masked_copy_incremental_state(self, incremental_state, another_state, mask): def mask_copy_state(state, another_state): if isinstance(state, list): assert isinstance(another_state, list) and len(state) == len(another_state) - return [mask_copy_state(state_i, another_state_i) \ - for state_i, another_state_i in zip(state, another_state)] + return [ + mask_copy_state(state_i, another_state_i) + for state_i, another_state_i in zip(state, another_state) + ] if state is not None: assert state.size(0) == mask.size(0) and another_state is not None and \ state.size() == another_state.size() @@ -324,12 +322,15 @@ def mask_copy_state(state, another_state): @register_model_architecture('speech_fconv', 'speech_fconv') def base_architecture(args): - args.encoder_conv_channels = getattr(args, 'encoder_conv_channels', - '[64, 64, 128, 128]') - args.encoder_conv_kernel_sizes = getattr(args, 'encoder_conv_kernel_sizes', - '[(3, 3), (3, 3), (3, 3), (3, 3)]') - args.encoder_conv_strides = getattr(args, 'encoder_conv_strides', - '[(1, 1), (2, 2), (1, 1), (2, 2)]') + args.encoder_conv_channels = getattr( + args, 'encoder_conv_channels', '[64, 64, 128, 128]', + ) + args.encoder_conv_kernel_sizes = getattr( + args, 'encoder_conv_kernel_sizes', '[(3, 3), (3, 3), (3, 3), (3, 3)]', + ) + args.encoder_conv_strides = getattr( + args, 'encoder_conv_strides', '[(1, 1), (2, 2), (1, 1), (2, 2)]', + ) args.dropout = getattr(args, 'dropout', 0.1) args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512) args.encoder_layers = getattr(args, 'encoder_layers', '[(512, 3)] * 20') diff --git a/espresso/models/speech_lstm.py b/espresso/models/speech_lstm.py index df1b5b215..ecb5400ec 100644 --- a/espresso/models/speech_lstm.py +++ b/espresso/models/speech_lstm.py @@ -9,6 +9,7 @@ from fairseq import options, utils, checkpoint_utils from fairseq.models import ( + FairseqDecoder, FairseqEncoder, FairseqIncrementalDecoder, FairseqLanguageModel, @@ -17,7 +18,6 @@ register_model_architecture, ) from fairseq.models.lstm import ( - AttentionLayer, Embedding, LSTM, LSTMCell, @@ -26,7 +26,7 @@ from fairseq.modules import AdaptiveSoftmax from espresso.modules import speech_attention -from espresso.tasks.speech_recognition import SpeechRecognitionTask +from espresso.tasks.speech_recognition import SpeechRecognitionEspressoTask import espresso.tools.utils as speech_utils @@ -80,7 +80,7 @@ def add_args(parser): 'layers (starting from the 2nd layer), i.e., the actual ' 'output of such layer is the sum of its input and output') parser.add_argument('--attention-type', type=str, metavar='STR', - choices=['bahdanau','luong'], + choices=['bahdanau', 'luong'], help='attention type') parser.add_argument('--attention-dim', type=int, metavar='N', help='attention dimension') @@ -154,20 +154,17 @@ def eval_str_nested_list_or_tuple(x, type=int): else: try: return type(x) - except: - raise ValueError - - out_channels = eval_str_nested_list_or_tuple(args.encoder_conv_channels, - type=int) - kernel_sizes = eval_str_nested_list_or_tuple( - args.encoder_conv_kernel_sizes, type=int) - strides = eval_str_nested_list_or_tuple(args.encoder_conv_strides, - type=int) - print('| input feature dimension: {}, channels: {}'.format(task.feat_dim, - task.feat_in_channels)) + except TypeError: + raise TypeError + + out_channels = eval_str_nested_list_or_tuple(args.encoder_conv_channels, type=int) + kernel_sizes = eval_str_nested_list_or_tuple(args.encoder_conv_kernel_sizes, type=int) + strides = 
eval_str_nested_list_or_tuple(args.encoder_conv_strides, type=int) + print('| input feature dimension: {}, channels: {}'.format(task.feat_dim, task.feat_in_channels)) assert task.feat_dim % task.feat_in_channels == 0 - conv_layers = ConvBNReLU(out_channels, kernel_sizes, strides, - in_channels=task.feat_in_channels) if not out_channels is None else None + conv_layers = ConvBNReLU( + out_channels, kernel_sizes, strides, in_channels=task.feat_in_channels, + ) if out_channels is not None else None rnn_encoder_input_size = task.feat_dim // task.feat_in_channels if conv_layers is not None: @@ -224,9 +221,10 @@ def eval_str_nested_list_or_tuple(x, type=int): def max_positions(self): """Maximum length supported by the model.""" - return (self.encoder.max_positions(), - self.decoder.max_positions() if self.pretrained_lm is None else \ - min(self.decoder.max_positions(), self.pretrained_lm.max_positions()) + return ( + self.encoder.max_positions(), + self.decoder.max_positions() if self.pretrained_lm is None else + min(self.decoder.max_positions(), self.pretrained_lm.max_positions()), ) def max_decoder_positions(self): @@ -294,7 +292,7 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): if args.is_wordlm and hasattr(task, 'word_dictionary'): dictionary = task.word_dictionary - elif isinstance(task, SpeechRecognitionTask): + elif isinstance(task, SpeechRecognitionEspressoTask): dictionary = task.target_dictionary else: dictionary = task.source_dictionary @@ -372,8 +370,9 @@ def output_lengths(self, in_lengths): def forward(self, src, src_lengths): # B X T X C -> B X (input channel num) x T X (C / input channel num) - x = src.view(src.size(0), src.size(1), self.in_channels, - src.size(2) // self.in_channels).transpose(1, 2) + x = src.view( + src.size(0), src.size(1), self.in_channels, src.size(2) // self.in_channels, + ).transpose(1, 2) for conv, bn in zip(self.convolutions, self.batchnorms): x = F.relu(bn(conv(x))) # B X (output channel num) x T X C' -> B X T X (output channel num) X C' @@ -396,7 +395,7 @@ def __init__( num_layers=1, dropout_in=0.1, dropout_out=0.1, bidirectional=False, residual=False, left_pad=False, pretrained_embed=None, padding_value=0., ): - super().__init__(None) # no src dictionary + super().__init__(None) # no src dictionary self.conv_layers_before = conv_layers_before self.num_layers = num_layers self.dropout_in = dropout_in @@ -435,8 +434,7 @@ def forward(self, src_tokens, src_lengths): ) if self.conv_layers_before is not None: - x, src_lengths, padding_mask = self.conv_layers_before(src_tokens, - src_lengths) + x, src_lengths, padding_mask = self.conv_layers_before(src_tokens, src_lengths) else: x, padding_mask = src_tokens, \ ~speech_utils.sequence_mask(src_lengths, src_tokens.size(1)) @@ -452,7 +450,7 @@ def forward(self, src_tokens, src_lengths): h0, c0 = x.new_zeros(*state_size), x.new_zeros(*state_size) for i in range(len(self.lstm)): - if self.residual and i > 0: # residual connection starts from the 2nd layer + if self.residual and i > 0: # residual connection starts from the 2nd layer prev_x = x # pack embedded source tokens into a PackedSequence packed_x = nn.utils.rnn.pack_padded_sequence(x, src_lengths.data.tolist()) @@ -462,7 +460,7 @@ def forward(self, src_tokens, src_lengths): # unpack outputs and apply dropout x, _ = nn.utils.rnn.pad_packed_sequence(packed_outs, padding_value=self.padding_value) - if i < len(self.lstm) - 1: # not applying dropout for the last layer + if i < len(self.lstm) - 1: # not applying dropout for the last 
layer x = F.dropout(x, p=self.dropout_out, training=self.training) x = x + prev_x if self.residual and i > 0 else x assert list(x.size()) == [seqlen, bsz, self.output_units] @@ -529,11 +527,13 @@ def __init__( if attn_type is None or attn_type.lower() == 'none': self.attention = None elif attn_type.lower() == 'bahdanau': - self.attention = speech_attention.BahdanauAttention(hidden_size, - encoder_output_units, attn_dim) + self.attention = speech_attention.BahdanauAttention( + hidden_size, encoder_output_units, attn_dim, + ) elif attn_type.lower() == 'luong': - self.attention = speech_attention.LuongAttention(hidden_size, - encoder_output_units) + self.attention = speech_attention.LuongAttention( + hidden_size, encoder_output_units, + ) else: raise ValueError('unrecognized attention type.') if hidden_size + encoder_output_units != out_embed_dim: @@ -566,7 +566,8 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, return self.output_layer(x), attn_scores def extract_features( - self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused): + self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused, + ): """ Similar to *forward* but only return features. @@ -600,10 +601,8 @@ def extract_features( prev_hiddens, prev_cells, input_feed = cached_state else: num_layers = len(self.layers) - prev_hiddens = [x.new_zeros(bsz, self.hidden_size) \ - for i in range(num_layers)] - prev_cells = [x.new_zeros(bsz, self.hidden_size) \ - for i in range(num_layers)] + prev_hiddens = [x.new_zeros(bsz, self.hidden_size) for i in range(num_layers)] + prev_cells = [x.new_zeros(bsz, self.hidden_size) for i in range(num_layers)] input_feed = x.new_zeros(bsz, self.encoder_output_units) \ if self.attention is not None else None @@ -618,14 +617,15 @@ def extract_features( for i, rnn in enumerate(self.layers): # recurrent cell hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i])) - if self.residual and i > 0: # residual connection starts from the 2nd layer + if self.residual and i > 0: # residual connection starts from the 2nd layer prev_layer_hidden = input[:, :hidden.size(1)] # compute and apply attention using the 1st layer's hidden state if self.attention is not None: if i == 0: - context, attn_scores[:, j, :], _ = self.attention(hidden, - encoder_outs, encoder_padding_mask) + context, attn_scores[:, j, :], _ = self.attention( + hidden, encoder_outs, encoder_padding_mask, + ) # hidden state concatenated with context vector becomes the # input to the next layer @@ -709,8 +709,10 @@ def masked_copy_incremental_state(self, incremental_state, another_cached_state, def mask_copy_state(state, another_state): if isinstance(state, list): assert isinstance(another_state, list) and len(state) == len(another_state) - return [mask_copy_state(state_i, another_state_i) \ - for state_i, another_state_i in zip(state, another_state)] + return [ + mask_copy_state(state_i, another_state_i) + for state_i, another_state_i in zip(state, another_state) + ] if state is not None: assert state.size(0) == mask.size(0) and another_state is not None and \ state.size() == another_state.size() @@ -719,7 +721,7 @@ def mask_copy_state(state, another_state): return torch.where(mask_unsqueezed, state, another_state) else: assert another_state is None - return None + return None new_state = tuple(map(mask_copy_state, cached_state, another_cached_state)) utils.set_incremental_state(self, incremental_state, 'cached_state', new_state) @@ -749,8 +751,9 @@ def 
Convolution2d(in_channels, out_channels, kernel_size, stride): stride = (stride, stride) assert kernel_size[0] % 2 == 1 and kernel_size[1] % 2 == 1 padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2) - m = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, \ - padding=padding) + m = nn.Conv2d( + in_channels, out_channels, kernel_size, stride=stride, padding=padding, + ) return m @@ -813,12 +816,15 @@ def lstm_wordlm_wsj(args): @register_model_architecture('speech_lstm', 'speech_lstm') def base_architecture(args): args.dropout = getattr(args, 'dropout', 0.4) - args.encoder_conv_channels = getattr(args, 'encoder_conv_channels', - '[64, 64, 128, 128]') - args.encoder_conv_kernel_sizes = getattr(args, 'encoder_conv_kernel_sizes', - '[(3, 3), (3, 3), (3, 3), (3, 3)]') - args.encoder_conv_strides = getattr(args, 'encoder_conv_strides', - '[(1, 1), (2, 2), (1, 1), (2, 2)]') + args.encoder_conv_channels = getattr( + args, 'encoder_conv_channels', '[64, 64, 128, 128]', + ) + args.encoder_conv_kernel_sizes = getattr( + args, 'encoder_conv_kernel_sizes', '[(3, 3), (3, 3), (3, 3), (3, 3)]', + ) + args.encoder_conv_strides = getattr( + args, 'encoder_conv_strides', '[(1, 1), (2, 2), (1, 1), (2, 2)]', + ) args.encoder_rnn_hidden_size = getattr(args, 'encoder_rnn_hidden_size', 320) args.encoder_rnn_layers = getattr(args, 'encoder_rnn_layers', 3) args.encoder_rnn_bidirectional = getattr(args, 'encoder_rnn_bidirectional', True) diff --git a/espresso/models/speech_transformer.py b/espresso/models/speech_transformer.py index 3a527ea3e..bba452fc8 100644 --- a/espresso/models/speech_transformer.py +++ b/espresso/models/speech_transformer.py @@ -3,8 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
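# For reference, a standalone sketch (not code from the patch): with the "same"
# padding used by Convolution2d above, each conv layer reduces a dimension by
# ceil-division with its stride, so the size of the features handed to the
# recurrent/transformer encoder follows the (size + s - 1) // s arithmetic seen
# in the model builders. Whether the time or the frequency component of a stride
# pair applies to the feature axis is an assumption of this sketch.
def conv_frontend_output_size(feat_dim, out_channels, strides):
    # strides: per-layer stride, an int or a (time, freq) pair
    size = feat_dim
    for stride in strides:
        s = stride[1] if isinstance(stride, (list, tuple)) else stride
        size = (size + s - 1) // s        # ceil division along the feature axis
    return size * out_channels[-1]        # channel dim is flattened into features

# e.g. with 40-dim features and the defaults above ([64, 64, 128, 128] channels,
# strides [(1, 1), (2, 2), (1, 1), (2, 2)]):
#   conv_frontend_output_size(40, [64, 64, 128, 128], [(1, 1), (2, 2), (1, 1), (2, 2)])  # -> 1280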
-import math - import torch import torch.nn as nn import torch.nn.functional as F @@ -108,20 +106,17 @@ def eval_str_nested_list_or_tuple(x, type=int): else: try: return type(x) - except: - raise ValueError - - out_channels = eval_str_nested_list_or_tuple(args.encoder_conv_channels, - type=int) - kernel_sizes = eval_str_nested_list_or_tuple( - args.encoder_conv_kernel_sizes, type=int) - strides = eval_str_nested_list_or_tuple(args.encoder_conv_strides, - type=int) - print('| input feature dimension: {}, channels: {}'.format(task.feat_dim, - task.feat_in_channels)) + except TypeError: + raise TypeError + + out_channels = eval_str_nested_list_or_tuple(args.encoder_conv_channels, type=int) + kernel_sizes = eval_str_nested_list_or_tuple(args.encoder_conv_kernel_sizes, type=int) + strides = eval_str_nested_list_or_tuple(args.encoder_conv_strides, type=int) + print('| input feature dimension: {}, channels: {}'.format(task.feat_dim, task.feat_in_channels)) assert task.feat_dim % task.feat_in_channels == 0 - conv_layers = ConvBNReLU(out_channels, kernel_sizes, strides, - in_channels=task.feat_in_channels) if not out_channels is None else None + conv_layers = ConvBNReLU( + out_channels, kernel_sizes, strides, in_channels=task.feat_in_channels, + ) if out_channels is not None else None transformer_encoder_input_size = task.feat_dim // task.feat_in_channels if conv_layers is not None: @@ -135,16 +130,18 @@ def eval_str_nested_list_or_tuple(x, type=int): transformer_encoder_input_size = \ (transformer_encoder_input_size + s - 1) // s transformer_encoder_input_size *= out_channels[-1] - - encoder = cls.build_encoder(args, conv_layers_before=conv_layers, - input_size=transformer_encoder_input_size) + + encoder = cls.build_encoder( + args, conv_layers_before=conv_layers, input_size=transformer_encoder_input_size, + ) decoder = cls.build_decoder(args, dict, decoder_embed_tokens) return SpeechTransformerModel(encoder, decoder) @classmethod def build_encoder(cls, args, conv_layers_before=None, input_size=83): - return SpeechTransformerEncoder(args, - conv_layers_before=conv_layers_before, input_size=input_size) + return SpeechTransformerEncoder( + args, conv_layers_before=conv_layers_before, input_size=input_size, + ) @classmethod def build_decoder(cls, args, dict, embed_tokens): @@ -206,8 +203,7 @@ def forward(self, src_tokens, src_lengths): padding elements of shape `(batch, src_len)` """ if self.conv_layers_before is not None: - x, src_lengths, encoder_padding_mask = self.conv_layers_before(src_tokens, - src_lengths) + x, src_lengths, encoder_padding_mask = self.conv_layers_before(src_tokens, src_lengths) else: x, encoder_padding_mask = src_tokens, \ ~speech_utils.sequence_mask(src_lengths, src_tokens.size(1)) @@ -258,14 +254,18 @@ class SpeechTransformerDecoder(TransformerDecoder): def masked_copy_incremental_state(self, incremental_state, another_cached_state, mask): pass + @register_model_architecture('speech_transformer', 'speech_transformer') def base_architecture(args): - args.encoder_conv_channels = getattr(args, 'encoder_conv_channels', - '[64, 64, 128, 128]') - args.encoder_conv_kernel_sizes = getattr(args, 'encoder_conv_kernel_sizes', - '[(3, 3), (3, 3), (3, 3), (3, 3)]') - args.encoder_conv_strides = getattr(args, 'encoder_conv_strides', - '[(1, 1), (2, 2), (1, 1), (2, 2)]') + args.encoder_conv_channels = getattr( + args, 'encoder_conv_channels', '[64, 64, 128, 128]', + ) + args.encoder_conv_kernel_sizes = getattr( + args, 'encoder_conv_kernel_sizes', '[(3, 3), (3, 3), (3, 3), (3, 3)]', + ) + 
args.encoder_conv_strides = getattr( + args, 'encoder_conv_strides', '[(1, 1), (2, 2), (1, 1), (2, 2)]', + ) args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512) args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 512) args.encoder_layers = getattr(args, 'encoder_layers', 6) diff --git a/espresso/models/tensorized_lookahead_language_model.py b/espresso/models/tensorized_lookahead_language_model.py index b9e39d00d..9a0370a57 100644 --- a/espresso/models/tensorized_lookahead_language_model.py +++ b/espresso/models/tensorized_lookahead_language_model.py @@ -3,13 +3,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import * +from typing import Any, Dict, List import torch from fairseq.models import FairseqLanguageModel, FairseqIncrementalDecoder from fairseq import utils -from espresso.data import TokenDictionary +from espresso.data import AsrDictionary from espresso.models.external_language_model import RawOutExternalLanguageModelBase from espresso.tools.tensorized_prefix_tree import TensorizedPrefixTree from espresso.tools.utils import tokenize @@ -33,7 +33,7 @@ class TensorizedLookaheadLanguageModel(RawOutExternalLanguageModelBase): """ def __init__(self, word_lm: FairseqLanguageModel, - subword_dict: TokenDictionary, + subword_dict: AsrDictionary, oov_penalty: float = 1e-4, open_vocab: bool = True ): @@ -53,7 +53,7 @@ class _TensorizedLookaheadLanguageModelDecoder(FairseqIncrementalDecoder): """ def __init__(self, word_lm: FairseqLanguageModel, - subword_dict: TokenDictionary, + subword_dict: AsrDictionary, oov_penalty: float = 1e-4, open_vocab: bool = True): super().__init__(word_lm.decoder.dictionary) @@ -67,7 +67,7 @@ def __init__(self, self.open_vocab = open_vocab self.zero = 1e-10 # a sufficiently small value to avoid the log(0) issue - word_dict: TokenDictionary = self.lm_decoder.dictionary + word_dict: AsrDictionary = self.lm_decoder.dictionary self.word_pad_idx = word_dict.pad() self.word_eos_idx = word_dict.eos() self.word_unk_idx = word_dict.unk() @@ -77,8 +77,8 @@ def __init__(self, self.subword_eos_idx = subword_dict.eos() self.subword_vocab_size = len(subword_dict) - tokenizer: Callable[[str], List[str]] = \ - lambda x: tokenize(x, non_lang_syms=subword_dict.non_lang_syms).split(' ') + def tokenizer(x: str) -> List[str]: + return tokenize(x, non_lang_syms=subword_dict.non_lang_syms).split(' ') self.tree = TensorizedPrefixTree.build(word_dict, subword_dict, tokenizer) assert self.tree.max_out_degree() <= self.subword_vocab_size @@ -145,7 +145,7 @@ def forward(self, if self.open_vocab: # set out_probs to oov_penalty * P(|h) (case 3 in Eqn. 
15) out_probs = self.oov_penalty * ( - cumsum_probs[:, :, self.word_unk_idx] - \ + cumsum_probs[:, :, self.word_unk_idx] - cumsum_probs[:, :, self.word_unk_idx - 1] ).unsqueeze(-1).repeat(1, 1, self.subword_vocab_size) @@ -172,7 +172,7 @@ def forward(self, sum_probs = torch.where( batch_node_not_root_mask, (cumsum_probs.squeeze(1).gather(-1, right_ranges.unsqueeze(-1)) - - cumsum_probs.squeeze(1).gather(-1, left_ranges.unsqueeze(-1))).squeeze(-1), + cumsum_probs.squeeze(1).gather(-1, left_ranges.unsqueeze(-1))).squeeze(-1), cumsum_probs.new([1.0]) ) # R[Batch] @@ -204,7 +204,7 @@ def forward(self, sum_probs < self.zero, cumsum_probs.new([self.zero]), ( - cumsum_probs.squeeze(1).gather(-1, word_idx.unsqueeze(-1)) - \ + cumsum_probs.squeeze(1).gather(-1, word_idx.unsqueeze(-1)) - cumsum_probs.squeeze(1).gather(-1, word_idx.unsqueeze(-1) - 1) ).squeeze(-1) / sum_probs ) # R[Batch] @@ -232,14 +232,16 @@ def reorder_incremental_state(self, incremental_state, new_order): self, incremental_state, 'cumsum_probs') if cumsum_probs is not None: new_cumsum_probs = cumsum_probs.index_select(0, new_order) - utils.set_incremental_state(self, incremental_state, 'cumsum_probs', - new_cumsum_probs) + utils.set_incremental_state( + self, incremental_state, 'cumsum_probs', new_cumsum_probs, + ) nodes = utils.get_incremental_state(self, incremental_state, 'nodes') if nodes is not None: new_nodes = nodes.index_select(0, new_order) - utils.set_incremental_state(self, incremental_state, 'nodes', - new_nodes) + utils.set_incremental_state( + self, incremental_state, 'nodes', new_nodes, + ) def get_normalized_probs(self, net_output, log_probs, sample=None): """Get normalized probabilities (or log probs) from a net's output.""" @@ -254,4 +256,3 @@ def extract_features(self, prev_output_tokens, encoder_out=None, incremental_sta def output_layer(self, features, **kwargs): pass - diff --git a/espresso/modules/speech_attention.py b/espresso/modules/speech_attention.py index 002964eb1..fcba56e91 100644 --- a/espresso/modules/speech_attention.py +++ b/espresso/modules/speech_attention.py @@ -9,8 +9,6 @@ from torch.nn import Parameter import torch.nn.functional as F -from fairseq import utils - class BaseAttention(nn.Module): """Base class for attention layers.""" @@ -62,12 +60,13 @@ def reset_parameters(self): def forward(self, query, value, key_padding_mask=None, state=None): # projected_query: 1 x bsz x embed_dim projected_query = self.query_proj(query).unsqueeze(0) - key = self.value_proj(value) # len x bsz x embed_dim + key = self.value_proj(value) # len x bsz x embed_dim if self.normalize: # normed_v = g * v / ||v|| normed_v = self.g * self.v / torch.norm(self.v) - attn_scores = (normed_v * torch.tanh(projected_query + key + \ - self.b)).sum(dim=2) # len x bsz + attn_scores = ( + normed_v * torch.tanh(projected_query + key + self.b) + ).sum(dim=2) # len x bsz else: attn_scores = self.v * torch.tanh(projected_query + key).sum(dim=2) @@ -76,7 +75,7 @@ def forward(self, query, value, key_padding_mask=None, state=None): key_padding_mask, float('-inf'), ).type_as(attn_scores) # FP16 support: cast to float and back - attn_scores = F.softmax(attn_scores, dim=0) # srclen x bsz + attn_scores = F.softmax(attn_scores, dim=0) # srclen x bsz # sum weighted value. 
context: bsz x value_dim context = (attn_scores.unsqueeze(2) * value).sum(dim=0) @@ -104,7 +103,7 @@ def reset_parameters(self): def forward(self, query, value, key_padding_mask=None, state=None): query = query.unsqueeze(1) # bsz x 1 x query_dim - key = self.value_proj(value).transpose(0, 1) # bsz x len x query_dim + key = self.value_proj(value).transpose(0, 1) # bsz x len x query_dim attn_scores = torch.bmm(query, key.transpose(1, 2)).squeeze(1) attn_scores = attn_scores.transpose(0, 1) # len x bsz if self.scale: @@ -115,11 +114,10 @@ def forward(self, query, value, key_padding_mask=None, state=None): key_padding_mask, float('-inf'), ).type_as(attn_scores) # FP16 support: cast to float and back - attn_scores = F.softmax(attn_scores, dim=0) # srclen x bsz + attn_scores = F.softmax(attn_scores, dim=0) # srclen x bsz # sum weighted value. context: bsz x value_dim context = (attn_scores.unsqueeze(2) * value).sum(dim=0) next_state = attn_scores return context, attn_scores, next_state - diff --git a/espresso/optim/lr_scheduler/reduce_lr_on_plateau_v2.py b/espresso/optim/lr_scheduler/reduce_lr_on_plateau_v2.py index 2435f3ce1..aa5a9e4ac 100644 --- a/espresso/optim/lr_scheduler/reduce_lr_on_plateau_v2.py +++ b/espresso/optim/lr_scheduler/reduce_lr_on_plateau_v2.py @@ -5,7 +5,7 @@ import torch.optim.lr_scheduler -from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler +from fairseq.optim.lr_scheduler import register_lr_scheduler from fairseq.optim.lr_scheduler.reduce_lr_on_plateau import ReduceLROnPlateau @@ -24,7 +24,7 @@ def __init__(self, args, optimizer): self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( self.optimizer.optimizer, patience=0, factor=args.lr_shrink, threshold=args.lr_threshold, min_lr=args.final_lr_scale * args.lr[0]) - + @staticmethod def add_args(parser): """Add arguments to the parser for this LR scheduler.""" @@ -35,7 +35,7 @@ def add_args(parser): parser.add_argument('--start-reduce-lr-epoch', default=0, type=int, metavar='N', help='start to reduce lr from the specified epoch') # fmt: on - + def step(self, epoch, val_loss=None): if epoch < self.args.start_reduce_lr_epoch: self.lr_scheduler.last_epoch = epoch diff --git a/espresso/speech_recognize.py b/espresso/speech_recognize.py index 5ca59aa6b..2156d19c3 100755 --- a/espresso/speech_recognize.py +++ b/espresso/speech_recognize.py @@ -15,7 +15,6 @@ from fairseq import checkpoint_utils, options, progress_bar, tasks, utils from fairseq.meters import StopwatchMeter, TimeMeter from fairseq.models import FairseqLanguageModel -from fairseq.utils import import_user_module from espresso.models.external_language_model import MultiLevelLanguageModel from espresso.models.tensorized_lookahead_language_model import TensorizedLookaheadLanguageModel @@ -54,16 +53,20 @@ def main(args): if hasattr(m, 'is_wordlm') and m.is_wordlm: # assume subword LM comes before word LM if isinstance(models[i - 1], FairseqLanguageModel): - models[i-1] = MultiLevelLanguageModel(m, models[i-1], + models[i-1] = MultiLevelLanguageModel( + m, models[i-1], subwordlm_weight=args.subwordlm_weight, oov_penalty=args.oov_penalty, - open_vocab=not args.disable_open_vocab) + open_vocab=not args.disable_open_vocab, + ) del models[i] print('| LM fusion with Multi-level LM') else: - models[i] = TensorizedLookaheadLanguageModel(m, dict, + models[i] = TensorizedLookaheadLanguageModel( + m, dict, oov_penalty=args.oov_penalty, - open_vocab=not args.disable_open_vocab) + open_vocab=not args.disable_open_vocab, + ) print('| LM fusion with 
Look-ahead Word LM') # assume subword LM comes after E2E models elif i == len(models) - 1 and isinstance(m, FairseqLanguageModel): @@ -89,8 +92,8 @@ def main(args): max_sentences=args.max_sentences, max_positions=utils.resolve_max_positions( task.max_positions(), - *[model.max_positions() if hasattr(model, 'encoder') \ - else (None, model.max_positions()) for model in models] + *[model.max_positions() if hasattr(model, 'encoder') + else (None, model.max_positions()) for model in models] ), ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, required_batch_size_multiple=8, @@ -102,7 +105,7 @@ def main(args): # Initialize generator if args.match_source_len: print('| The option match_source_len is not applicable to ' - 'speech recognition. Ignoring it.') + 'speech recognition. Ignoring it.') gen_timer = StopwatchMeter() generator = task.build_generator(args) @@ -122,8 +125,9 @@ def main(args): prefix_tokens = sample['target'][:, :args.prefix_size] gen_timer.start() - hypos = task.inference_step(generator, models, sample, prefix_tokens, - lm_weight=args.lm_weight) + hypos = task.inference_step( + generator, models, sample, prefix_tokens, lm_weight=args.lm_weight, + ) num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos) gen_timer.stop(num_generated_tokens) @@ -142,13 +146,14 @@ def main(args): if has_target: target_str = task.dataset(args.gen_subset).tgt.get_original_tokens(sample_id) if not args.quiet: - target_sent = dict.tokens_to_sentence(target_str, - use_unk_sym=False, bpe_symbol=args.remove_bpe) + target_sent = dict.tokens_to_sentence( + target_str, use_unk_sym=False, bpe_symbol=args.remove_bpe, + ) print('T-{}\t{}'.format(utt_id, target_sent)) # Process top predictions for j, hypo in enumerate(hypos[i][:args.nbest]): - hypo_str = dict.string(hypo['tokens'].int().cpu()) # not removing bpe at this point + hypo_str = dict.string(hypo['tokens'].int().cpu()) # not removing bpe at this point if not args.quiet or i == 0: hypo_sent = dict.tokens_to_sentence(hypo_str, bpe_symbol=args.remove_bpe) diff --git a/espresso/speech_train.py b/espresso/speech_train.py index 6b46df933..fa84a59b3 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -253,7 +253,7 @@ def validate(args, trainer, task, epoch_itr, subsets): for k, v in log_output.items(): if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', - 'sample_size', 'word_count', 'char_count']: + 'sample_size', 'word_count', 'char_count']: continue if k == 'word_error': extra_meters['wer'].update( diff --git a/espresso/tasks/language_modeling_for_asr.py b/espresso/tasks/language_modeling_for_asr.py index 8547423c3..444ed66af 100644 --- a/espresso/tasks/language_modeling_for_asr.py +++ b/espresso/tasks/language_modeling_for_asr.py @@ -8,11 +8,11 @@ import os from fairseq import tokenizer +from fairseq.data import TruncatedDictionary from fairseq.tasks import register_task - from fairseq.tasks.language_modeling import LanguageModelingTask -from espresso.data import TokenDictionary +from espresso.data import AsrDictionary @register_task("language_modeling_for_asr") @@ -21,9 +21,9 @@ class LanguageModelingForASRTask(LanguageModelingTask): Train a language model. Args: - dictionary (~fairseq.data.TokenDictionary): the dictionary for the input of + dictionary (~fairseq.data.AsrDictionary): the dictionary for the input of the language model - output_dictionary (~fairseq.data.TokenDictionary): the dictionary for the + output_dictionary (~fairseq.data.AsrDictionary): the dictionary for the output of the language model. 
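# Why language models are wrapped as (None, max_positions) in the hunk above: ASR models
# limit both source frames and target tokens, while a fused LM only constrains the target
# side. Minimal stand-in for fairseq.utils.resolve_max_positions (element-wise min, with
# None meaning "no limit"); the values and helper below are illustrative, not the real API.
def resolve(*limits):
    out = None
    for limit in limits:
        if out is None:
            out = tuple(limit)
        else:
            out = tuple(
                a if b is None else b if a is None else min(a, b)
                for a, b in zip(out, limit)
            )
    return out

asr_limit = (9999, 999)   # (max source frames, max target tokens) of an encoder-decoder model
lm_limit = (None, 512)    # a language model constrains only the target length
print(resolve((9999, 999), asr_limit, lm_limit))   # -> (9999, 512)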
In most cases it will be the same as *dictionary*, but could possibly be a more limited version of the dictionary (if ``--output-dictionary-size`` is used). @@ -69,7 +69,7 @@ def load_dictionary(cls, filename): Args: filename (str): the filename """ - return TokenDictionary.load(filename) + return AsrDictionary.load(filename) @classmethod def build_dictionary(cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8): @@ -85,9 +85,9 @@ def build_dictionary(cls, filenames, workers=1, threshold=-1, nwords=-1, padding multiple of 8, which is important on some hardware (e.g., Nvidia Tensor Cores). """ - d = TokenDictionary() + d = AsrDictionary() for filename in filenames: - TokenDictionary.add_file_to_dictionary(filename, d, tokenizer.tokenize_line, workers) + AsrDictionary.add_file_to_dictionary(filename, d, tokenizer.tokenize_line, workers) d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor) return d @@ -105,7 +105,7 @@ def setup_task(cls, args, **kwargs): assert len(paths) > 0 dict_path = os.path.join(paths[0], "dict.txt") if args.dict is None \ else args.dict - dictionary = TokenDictionary.load(dict_path) + dictionary = AsrDictionary.load(dict_path) print("| dictionary: {} types".format(len(dictionary))) output_dictionary = dictionary if args.output_dictionary_size >= 0: diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index 16f31f426..7258b9844 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -3,35 +3,30 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import torch - -import itertools import os -import re + +import torch from fairseq import options -from fairseq.data import ( - ConcatDataset, - data_utils, -) +from fairseq.data import ConcatDataset from fairseq.tasks import FairseqTask, register_task from espresso.data import ( + AsrDictionary, + AsrTextDataset, ScpCachedDataset, SpeechDataset, - TokenDictionary, - TokenTextDataset, ) -@register_task('speech_recognition') -class SpeechRecognitionTask(FairseqTask): +@register_task('speech_recognition_espresso') +class SpeechRecognitionEspressoTask(FairseqTask): """ Transcribe from speech (source) to token text (target). Args: - dict (~fairseq.data.TokenDictionary): dictionary for the output tokens + dict (~fairseq.data.AsrDictionary): dictionary for the output tokens .. 
note:: @@ -102,7 +97,7 @@ def load_dictionary(cls, filename, non_lang_syms=None): filename (str): the filename non_lang_syms (str): non_lang_syms filename """ - return TokenDictionary.load(filename, f_non_lang_syms=non_lang_syms) + return AsrDictionary.load(filename, f_non_lang_syms=non_lang_syms) @classmethod def build_dictionary(cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8): @@ -165,7 +160,7 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs): text_files = self.args.valid_text_files elif split == 'test': feat_files = self.args.test_feat_files - text_files = self.args.test_text_files # can be empty + text_files = self.args.test_text_files # can be empty if text_files is None: text_files = [None] * len(feat_files) elif split == 'train_subset': @@ -178,11 +173,11 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs): file_pairs = zip(feat_files, text_files) for feat, text in file_pairs: assert ScpCachedDataset.exists(feat), feat + ' does not exists' - assert text is None or TokenTextDataset.exists(text), text + ' does not exists' + assert text is None or AsrTextDataset.exists(text), text + ' does not exists' src_datasets.append(ScpCachedDataset(feat, ordered_prefetch=True)) print('| {} {} examples'.format(feat, len(src_datasets[-1]))) if text is not None: - tgt_datasets.append(TokenTextDataset(text, self.dict)) + tgt_datasets.append(AsrTextDataset(text, self.dict)) print('| {} {} examples'.format(text, len(tgt_datasets[-1]))) if not combine: @@ -228,7 +223,7 @@ def build_generator(self, args): if args.score_reference: args.score_reference = False print('| --score-reference is not applicable to speech recognition,' - ' ignoring it.') + ' ignoring it.') from fairseq.sequence_generator import SequenceGenerator return SequenceGenerator( self.target_dictionary, @@ -253,11 +248,11 @@ def build_generator(self, args): def build_dataset_for_inference(self, src_tokens, src_lengths): return SpeechDataset(src_tokens, src_lengths) - def inference_step(self, generator, models, sample, prefix_tokens=None, - lm_weight=0.0): + def inference_step(self, generator, models, sample, prefix_tokens=None, lm_weight=0.0): with torch.no_grad(): - return generator.generate(models, sample, prefix_tokens=prefix_tokens, - lm_weight=lm_weight) + return generator.generate( + models, sample, prefix_tokens=prefix_tokens, lm_weight=lm_weight, + ) def max_positions(self): """Return the max sentence length allowed by the task.""" @@ -265,10 +260,10 @@ def max_positions(self): @property def target_dictionary(self): - """Return the target :class:`~fairseq.data.TokenDictionary`.""" + """Return the target :class:`~fairseq.data.AsrDictionary`.""" return self.dict @property def word_dictionary(self): - """Return the target :class:`~fairseq.data.TokenDictionary`.""" + """Return the target :class:`~fairseq.data.AsrDictionary`.""" return self.word_dict diff --git a/espresso/tools/compute_wer.py b/espresso/tools/compute_wer.py index 62a44e423..4b6fae77c 100755 --- a/espresso/tools/compute_wer.py +++ b/espresso/tools/compute_wer.py @@ -5,7 +5,8 @@ # LICENSE file in the root directory of this source tree. 
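# A small sketch of the file pairing done in load_dataset() above: each feature scp file is
# matched with the token-text file of the same split, and for recognition-only test data the
# text list falls back to [None, ...]. File names here are hypothetical.
feat_files = ['feats.1.scp', 'feats.2.scp']
text_files = None                      # e.g. a test set without transcripts
if text_files is None:
    text_files = [None] * len(feat_files)
for feat, text in zip(feat_files, text_files):
    # real code: ScpCachedDataset(feat, ordered_prefetch=True) and, when text is not None,
    # AsrTextDataset(text, dictionary); multiple pairs are merged with ConcatDataset
    print(feat, '->', text)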
import argparse -import sys, re +import re +import sys from collections import Counter from espresso.tools.utils import edit_distance @@ -52,8 +53,10 @@ def main(args): assert m is not None word_filters.append([m.group(1), m.group(2)]) else: - print('Unsupported pattern: "{}", ignored'.format(line), - file=sys.stderr) + print( + 'Unsupported pattern: "{}", ignored'.format(line), + file=sys.stderr, + ) refs = {} with open(args.ref_text, 'r', encoding='utf-8') as f: @@ -68,7 +71,7 @@ def main(args): utt_id, text = line.strip().split(None, 1) assert utt_id in refs, utt_id ref, hyp = refs[utt_id], text - + # filter words according to word_filters (support re.sub only) for pattern, repl in word_filters: ref = re.sub(pattern, repl, ref) @@ -82,8 +85,9 @@ def main(args): wer_counter += counter assert wer_counter['words'] > 0 - wer = float(wer_counter['sub'] + wer_counter['ins'] + \ - wer_counter['del']) / wer_counter['words'] * 100 + wer = float( + wer_counter['sub'] + wer_counter['ins'] + wer_counter['del'] + ) / wer_counter['words'] * 100 sub = float(wer_counter['sub']) / wer_counter['words'] * 100 ins = float(wer_counter['ins']) / wer_counter['words'] * 100 dlt = float(wer_counter['del']) / wer_counter['words'] * 100 @@ -95,4 +99,4 @@ def main(args): if __name__ == '__main__': parser = get_parser() args = parser.parse_args() - main(args) + main(args) diff --git a/espresso/tools/lexical_prefix_tree.py b/espresso/tools/lexical_prefix_tree.py new file mode 100644 index 000000000..79a0e8f38 --- /dev/null +++ b/espresso/tools/lexical_prefix_tree.py @@ -0,0 +1,64 @@ +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, List + +from espresso.data import AsrDictionary + + +def lexical_prefix_tree( + word_dict: AsrDictionary, + subword_dict: AsrDictionary, + subword_tokenizer: Callable[[str], List[str]] = None +): + """Build a lexical prefix tree for words. + + Args: + word_dict: an instance of :class:`fairseq.data.AsrDictionary`. + subword_dict: an instance of :class:`fairseq.data.AsrDictionary`. + subword_tokenizer (callable): a function that takes a word string as its + only one argument, and returns a list of subwords as a result of + tokenization. + + Return: + root (Node): the root of the prefix tree, where each node has the fields: + ('children': Dict[int,Node], 'word_idx': int, 'word_set': Tuple[int]). + 'children' is subword_idx -> node, and 'word_set' is (first-1, last), + where [first, last] is the range of the word indexes (inclusive) in + the word dictionary who share the same prefix at that node. + We assume words in the word dictionary are in lexical order. 
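# Worked example of the word_set = (first - 1, last) convention described above, using
# made-up word indices: suppose the word dictionary assigns 4 -> "ab", 5 -> "ac", 6 -> "ba"
# and subwords are single characters. The node reached by prefix "a" then covers words 4..5
# and stores (3, 5); prefix "b" covers only word 6 and stores (5, 6). The tree code maintains
# these intervals with the same min/max updates shown below.
word_set = None
for widx in (4, 5):                      # both words start with subword "a"
    if word_set is None:
        word_set = (widx - 1, widx)
    else:
        word_set = (min(word_set[0], widx - 1), max(word_set[1], widx))
print(word_set)                          # -> (3, 5)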
+ """ + + class Node(object): + def __init__(self, children={}, word_idx=-1, word_set=None): + self.children = children + self.word_idx = word_idx + self.word_set = word_set + + special_symbols = [word_dict.pad(), word_dict.eos(), word_dict.unk()] + assert 0 in special_symbols # to ensure widx - 1 >= 0 + root = Node({}, -1, None) + for widx in range(len(word_dict)): + if widx not in special_symbols: # skip , , + # tokenize a word into a list of subwords + subwords = subword_tokenizer(word_dict[widx]) \ + if subword_tokenizer is not None else list(word_dict[widx]) + if any(subword_dict.index(s) == subword_dict.unk() for s in subwords): + # skip words containing any unknown subwords + continue + children = root.children + for i, s in enumerate(subwords): + sidx = subword_dict.index(s) + if sidx not in children: # make a new node + children[sidx] = Node({}, -1, (widx - 1, widx)) + else: + children[sidx].word_set = ( + min(children[sidx].word_set[0], widx - 1), + max(children[sidx].word_set[1], widx) + ) + if i == len(subwords) - 1: # if word end, set word_idx + children[sidx].word_idx = widx + children = children[sidx].children # move to children + return root diff --git a/espresso/tools/tensorized_prefix_tree.py b/espresso/tools/tensorized_prefix_tree.py index 77bf37697..ed76aa00c 100644 --- a/espresso/tools/tensorized_prefix_tree.py +++ b/espresso/tools/tensorized_prefix_tree.py @@ -3,14 +3,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import os, re import numpy as np -from typing import * +from typing import Callable, List import torch -from espresso.data import TokenDictionary -from espresso.tools.utils import lexical_prefix_tree +from espresso.data import AsrDictionary +from espresso.tools.lexical_prefix_tree import lexical_prefix_tree class TensorizedPrefixTree: @@ -43,8 +42,8 @@ def to_cuda(self, device): @staticmethod def build( - word_dict: TokenDictionary, - subword_dict: TokenDictionary, + word_dict: AsrDictionary, + subword_dict: AsrDictionary, subword_tokenizer: Callable[[str], List[str]] = None ): """ diff --git a/espresso/tools/text2token.py b/espresso/tools/text2token.py index e020ba26f..455d8ebe1 100755 --- a/espresso/tools/text2token.py +++ b/espresso/tools/text2token.py @@ -39,8 +39,11 @@ def main(args): with (open(args.text, 'r', encoding='utf-8') if args.text else sys.stdin) as f: for line in f: entry = line.rstrip().split() - tokenized = tokenize(' '.join(entry[args.skip_ncols:]), - space=args.space, non_lang_syms=nls) + tokenized = tokenize( + ' '.join(entry[args.skip_ncols:]), + space=args.space, + non_lang_syms=nls, + ) if args.skip_ncols > 0: if args.ends_with_space: print(' '.join(entry[:args.skip_ncols]) + ' ' + tokenized + ' ' + args.space) diff --git a/espresso/tools/text2vocabulary.py b/espresso/tools/text2vocabulary.py index ee3d314c0..50681b967 100755 --- a/espresso/tools/text2vocabulary.py +++ b/espresso/tools/text2vocabulary.py @@ -5,7 +5,8 @@ # LICENSE file in the root directory of this source tree. 
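# A simplified sketch of the character-level tokenization that text2token.py relies on
# (espresso.tools.utils.tokenize): ordinary words are split into characters, spaces become a
# dedicated <space> symbol, and non-language symbols such as <noise> are kept whole. This is
# an illustrative re-implementation, not the actual function.
import re

def char_tokenize(sent, space='<space>', non_lang_syms=('<noise>',)):
    sent = ' '.join(sent.strip().split())           # normalize whitespace
    pattern = '(' + '|'.join(map(re.escape, non_lang_syms)) + ')'
    tokens = []
    for piece in re.split(pattern, sent):
        if piece in non_lang_syms:
            tokens.append(piece)
        else:
            tokens.extend(space if ch == ' ' else ch for ch in piece)
    return ' '.join(tokens)

print(char_tokenize('HELLO <noise> WORLD'))
# -> H E L L O <space> <noise> <space> W O R L D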
import argparse -import os, sys +import os +import sys from collections import Counter @@ -116,4 +117,4 @@ def main(args): if __name__ == '__main__': parser = get_parser() args = parser.parse_args() - main(args) + main(args) diff --git a/espresso/tools/utils.py b/espresso/tools/utils.py index cdf124d7f..bf2cfbeaa 100644 --- a/espresso/tools/utils.py +++ b/espresso/tools/utils.py @@ -3,17 +3,15 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import os, re +import os +import re import numpy as np from collections import Counter -from typing import Callable, List import torch from fairseq import utils -from espresso.data import TokenDictionary - def tokenize(sent, space='', non_lang_syms=None): assert isinstance(sent, str) @@ -39,6 +37,7 @@ def tokenize(sent, space='', non_lang_syms=None): tokens = [space if token == ' ' else token for token in tokens] return ' '.join(tokens) + def collate_frames(values, pad_value=0.0, left_pad=False): """Convert a list of 2d tensor into a padded 3d tensor.""" assert values[0].dim() == 2, "expected 2, got " + str(values[0].dim) @@ -53,20 +52,25 @@ def collate_frames(values, pad_value=0.0, left_pad=False): dst.copy_(v) return res + def sequence_mask(sequence_length, max_len=None): if max_len is None: max_len = sequence_length.data.max() else: assert sequence_length.data.max().item() <= utils.item(max_len) batch_size = sequence_length.size(0) - seq_range = torch.arange(0, max_len).to(device=sequence_length.device, - dtype=sequence_length.dtype) + seq_range = torch.arange(0, max_len).to( + device=sequence_length.device, + dtype=sequence_length.dtype, + ) seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) seq_length_expand = sequence_length.unsqueeze(1).expand_as(seq_range_expand) return seq_range_expand < seq_length_expand -def convert_padding_direction(src_frames, src_lengths, right_to_left=False, - left_to_right=False): + +def convert_padding_direction( + src_frames, src_lengths, right_to_left=False, left_to_right=False, +): """Counterpart of :func:`~fairseq.utils.convert_padding_direction`, operating on 3d tensors of size B x T x C. Note that this function is unware of whether it has already been right padded or left padded (since any real @@ -87,6 +91,7 @@ def convert_padding_direction(src_frames, src_lengths, right_to_left=False, index = torch.remainder(range + num_pads, max_len) return src_frames.gather(1, index) + def plot_attention(attention, hypo_sent, utt_id, save_dir): """This function plots the attention for an example and save the plot in save_dir with .pdf as its filename. 
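# Quick illustration of what sequence_mask() above computes, written with plain torch as a
# one-line equivalent for B=2 utterances of lengths 3 and 1: True marks real frames, False
# marks padding. collate_frames() pads features to the max length in the same right-padded
# layout, so masks like this can be rebuilt from the collected lengths.
import torch

lengths = torch.tensor([3, 1])
mask = torch.arange(4).unsqueeze(0) < lengths.unsqueeze(1)
print(mask)
# tensor([[ True,  True,  True, False],
#         [ True, False, False, False]])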
@@ -109,6 +114,7 @@ def plot_attention(attention, hypo_sent, utt_id, save_dir): plt.savefig(filename, bbox_inches='tight') plt.close() + def edit_distance(ref, hyp): """This function is to calculate the edit distance of reference sentence and the hypothesis sentence using dynamic programming, and also backtrace to get @@ -151,8 +157,10 @@ def edit_distance(ref, hyp): while True: if i == 0 and j == 0: break - elif i >= 1 and j >= 1 and dist[i][j] == dist[i - 1][j - 1] and \ - ref[i - 1] == hyp[j - 1]: + elif ( + i >= 1 and j >= 1 and dist[i][j] == dist[i - 1][j - 1] and + ref[i - 1] == hyp[j - 1] + ): steps.append('corr') i, j = i - 1, j - 1 elif i >= 1 and j >= 1 and dist[i][j] == dist[i - 1][j - 1] + 1: @@ -168,12 +176,14 @@ def edit_distance(ref, hyp): i = i - 1 steps = steps[::-1] - counter = Counter({'words': len(ref), 'corr': 0, 'sub': 0, 'ins': 0, - 'del': 0}) + counter = Counter( + {'words': len(ref), 'corr': 0, 'sub': 0, 'ins': 0, 'del': 0} + ) counter.update(steps) return dist, steps, counter + def aligned_print(ref, hyp, steps): """This funcition is to print the result of comparing reference and hypothesis sentences in an aligned way. @@ -199,7 +209,7 @@ def aligned_print(ref, hyp, steps): for i in range(len(steps)): delim = ' ' if i < len(steps) - 1 else '\n' if steps[i] == 'sub': - ref_idx = i - steps[:i].count('ins') + ref_idx = i - steps[:i].count('ins') hyp_idx = i - steps[:i].count('del') if len(ref[ref_idx]) < len(hyp[hyp_idx]): out_str += ref[ref_idx] + \ @@ -218,7 +228,7 @@ def aligned_print(ref, hyp, steps): for i in range(len(steps)): delim = ' ' if i < len(steps) - 1 else '\n' if steps[i] == 'sub': - ref_idx = i - steps[:i].count('ins') + ref_idx = i - steps[:i].count('ins') hyp_idx = i - steps[:i].count('del') if len(ref[ref_idx]) > len(hyp[hyp_idx]): out_str += hyp[hyp_idx] + \ @@ -237,7 +247,7 @@ def aligned_print(ref, hyp, steps): for i in range(len(steps)): delim = ' ' if i < len(steps) - 1 else '\n' if steps[i] == 'sub': - ref_idx = i - steps[:i].count('ins') + ref_idx = i - steps[:i].count('ins') hyp_idx = i - steps[:i].count('del') if len(ref[ref_idx]) > len(hyp[hyp_idx]): out_str += 'S' + ' ' * (len(ref[ref_idx]) - 1) + delim @@ -259,58 +269,3 @@ def aligned_print(ref, hyp, steps): out_str += '\n' return out_str - -def lexical_prefix_tree( - word_dict: TokenDictionary, - subword_dict: TokenDictionary, - subword_tokenizer: Callable[[str], List[str]] = None -): - """Build a lexical prefix tree for words. - - Args: - word_dict: an instance of :class:`fairseq.data.TokenDictionary`. - subword_dict: an instance of :class:`fairseq.data.TokenDictionary`. - subword_tokenizer (callable): a function that takes a word string as its - only one argument, and returns a list of subwords as a result of - tokenization. - - Return: - root (Node): the root of the prefix tree, where each node has the fields: - ('children': Dict[int,Node], 'word_idx': int, 'word_set': Tuple[int]). - 'children' is subword_idx -> node, and 'word_set' is (first-1, last), - where [first, last] is the range of the word indexes (inclusive) in - the word dictionary who share the same prefix at that node. - We assume words in the word dictionary are in lexical order. 
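# Usage sketch for edit_distance() above; the same case appears in
# tests/espresso/test_speech_utils.py. Assumes espresso is importable.
import espresso.tools.utils as speech_utils

ref = ['a', 'b', 'c']
hyp = ['d', 'b', 'c', 'e', 'f']
dist, steps, counter = speech_utils.edit_distance(ref, hyp)
print(steps)   # ['sub', 'corr', 'corr', 'ins', 'ins']
# counter: 3 reference words, 2 correct, 1 substituted, 2 inserted, 0 deleted
wer = 100.0 * (counter['sub'] + counter['ins'] + counter['del']) / counter['words']
print(wer)     # 100.0 (insertions drive WER up even though 2 words are correct)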
- """ - - class Node(object): - def __init__(self, children={}, word_idx=-1, word_set=None): - self.children = children - self.word_idx = word_idx - self.word_set = word_set - - special_symbols = [word_dict.pad(), word_dict.eos(), word_dict.unk()] - assert 0 in special_symbols # to ensure widx - 1 >= 0 - root = Node({}, -1, None) - for widx in range(len(word_dict)): - if widx not in special_symbols: # skip , , - # tokenize a word into a list of subwords - subwords = subword_tokenizer(word_dict[widx]) \ - if subword_tokenizer is not None else list(word_dict[widx]) - if any(subword_dict.index(s) == subword_dict.unk() for s in subwords): - # skip words containing any unknown subwords - continue - children = root.children - for i, s in enumerate(subwords): - sidx = subword_dict.index(s) - if sidx not in children: # make a new node - children[sidx] = Node({}, -1, (widx - 1, widx)) - else: - children[sidx].word_set = ( - min(children[sidx].word_set[0], widx - 1), - max(children[sidx].word_set[1], widx) - ) - if i == len(subwords) - 1: # if word end, set word_idx - children[sidx].word_idx = widx - children = children[sidx].children # move to children - return root diff --git a/espresso/tools/wer.py b/espresso/tools/wer.py index 213fe9d4a..3d1f826dd 100644 --- a/espresso/tools/wer.py +++ b/espresso/tools/wer.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import re +import sys from collections import Counter, OrderedDict @@ -41,8 +42,10 @@ def parse_wer_output_filter(self, wer_output_filter): assert m is not None self.word_filters.append([m.group(1), m.group(2)]) else: - print('Unsupported pattern: "{}", ignored'.format(line), - file=sys.stderr) + print( + 'Unsupported pattern: "{}", ignored'.format(line), + file=sys.stderr, + ) def add_prediction(self, utt_id, pred, bpe_symbol=None): if not isinstance(utt_id, str): @@ -50,12 +53,12 @@ def add_prediction(self, utt_id, pred, bpe_symbol=None): if not isinstance(pred, str): raise TypeError('pred must be a string(got {})'.format(type(pred))) - assert not utt_id in self.char_results, \ + assert utt_id not in self.char_results, \ 'Duplicated utterance id detected: {}'.format(utt_id) self.char_results[utt_id] = pred + '\n' pred_words = self.dict.tokens_to_sentence(pred, bpe_symbol=bpe_symbol) - assert not utt_id in self.results, \ + assert utt_id not in self.results, \ 'Duplicated utterance id detected: {}'.format(utt_id) self.results[utt_id] = pred_words + '\n' @@ -76,8 +79,9 @@ def add_evaluation(self, utt_id, ref, pred, bpe_symbol=None): pred = ' '.join([x for x in pred_list if x not in non_lang_syms]) # char level counts - _, _, counter = speech_utils.edit_distance(ref.strip().split(), - pred.strip().split()) + _, _, counter = speech_utils.edit_distance( + ref.strip().split(), pred.strip().split(), + ) self.char_counter += counter # word level counts @@ -90,18 +94,21 @@ def add_evaluation(self, utt_id, ref, pred, bpe_symbol=None): pred_words = re.sub(pattern, repl, pred_words) ref_word_list, pred_word_list = ref_words.split(), pred_words.split() - _, steps, counter = speech_utils.edit_distance(ref_word_list, - pred_word_list) + _, steps, counter = speech_utils.edit_distance( + ref_word_list, pred_word_list, + ) self.word_counter += counter - assert not utt_id in self.aligned_results, \ + assert utt_id not in self.aligned_results, \ 'Duplicated utterance id detected: {}'.format(utt_id) - self.aligned_results[utt_id] = speech_utils.aligned_print(ref_word_list, - pred_word_list, steps) + self.aligned_results[utt_id] = 
speech_utils.aligned_print( + ref_word_list, pred_word_list, steps, + ) def cer(self): assert self.char_counter['words'] > 0 - cer = float(self.char_counter['sub'] + self.char_counter['ins'] + \ - self.char_counter['del']) / self.char_counter['words'] * 100 + cer = float( + self.char_counter['sub'] + self.char_counter['ins'] + self.char_counter['del'] + ) / self.char_counter['words'] * 100 sub = float(self.char_counter['sub']) / self.char_counter['words'] * 100 ins = float(self.char_counter['ins']) / self.char_counter['words'] * 100 dlt = float(self.char_counter['del']) / self.char_counter['words'] * 100 @@ -109,8 +116,9 @@ def cer(self): def wer(self): assert self.word_counter['words'] > 0 - wer = float(self.word_counter['sub'] + self.word_counter['ins'] + \ - self.word_counter['del']) / self.word_counter['words'] * 100 + wer = float( + self.word_counter['sub'] + self.word_counter['ins'] + self.word_counter['del'] + ) / self.word_counter['words'] * 100 sub = float(self.word_counter['sub']) / self.word_counter['words'] * 100 ins = float(self.word_counter['ins']) / self.word_counter['words'] * 100 dlt = float(self.word_counter['del']) / self.word_counter['words'] * 100 @@ -175,4 +183,3 @@ def print_aligned_results(self): for utt_id in self.aligned_results: res += utt_id + '\n' + self.aligned_results[utt_id] return res - diff --git a/examples/asr_librispeech/run.sh b/examples/asr_librispeech/run.sh index cd854cd37..084a418ab 100755 --- a/examples/asr_librispeech/run.sh +++ b/examples/asr_librispeech/run.sh @@ -205,7 +205,7 @@ if [ ${stage} -le 7 ]; then mkdir -p $dir/logs log_file=$dir/logs/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" - CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --task speech_recognition --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --task speech_recognition_espresso --seed 1 --user-dir espresso \ --log-interval 4000 --log-format simple --print-training-sample-interval 2000 \ --num-workers 0 --max-tokens 26000 --max-sentences 24 --curriculum 1 \ --valid-subset $valid_subset --max-sentences-valid 48 --ddp-backend no_c10d \ @@ -237,7 +237,7 @@ if [ ${stage} -le 8 ]; then feat=${dumpdir}/$dataset/delta${do_delta}/feats.scp text=data/$dataset/token_text CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py \ - --task speech_recognition --user-dir espresso --max-tokens 15000 --max-sentences 24 \ + --task speech_recognition_espresso --user-dir espresso --max-tokens 15000 --max-sentences 24 \ --num-shards 1 --shard-id 0 --test-feat-files $feat --test-text-files $text \ --dict $dict --remove-bpe sentencepiece \ --max-source-positions 9999 --max-target-positions 999 \ diff --git a/examples/asr_swbd/local/prepare_ctm.py b/examples/asr_swbd/local/prepare_ctm.py index b1119220d..8183bc27e 100755 --- a/examples/asr_swbd/local/prepare_ctm.py +++ b/examples/asr_swbd/local/prepare_ctm.py @@ -10,9 +10,7 @@ # The format of output is like "en_4156 A start_time duration oh", and so on import argparse -import math import re -import sys def get_parser(): @@ -36,7 +34,7 @@ def main(args): split_content = [] # store ctm results for i, line in enumerate(content): elements = line.strip().split() - + # The first field contains the information of the utterance utt_info = elements[0] infos = re.split('[-_]', utt_info) @@ -58,24 +56,24 @@ def main(args): duration = end_time - start_time_cur split_content.append( ' '.join([utt_id, channel, str(round(start_time_cur, 2)), - str(round(duration, 2)), word]) + 
str(round(duration, 2)), word]) ) else: - duration = time_step + duration = time_step split_content.append( ' '.join([utt_id, channel, str(round(start_time_cur, 2)), - str(round(duration, 2)), word]) + str(round(duration, 2)), word]) ) if j == 0: split_content.append( ' '.join([utt_id, channel, str(round(start_time, 2)), - str(round(time_diff, 2)), '[noise]']) + str(round(time_diff, 2)), '[noise]']) ) with open(args.ctm_result, 'w', encoding='utf-8') as f: for line in split_content: f.write(line + '\n') - + if __name__ == '__main__': parser = get_parser() diff --git a/examples/asr_swbd/run.sh b/examples/asr_swbd/run.sh index 634952d0e..8ded8cdd9 100755 --- a/examples/asr_swbd/run.sh +++ b/examples/asr_swbd/run.sh @@ -246,7 +246,7 @@ if [ $stage -le 6 ]; then mkdir -p $dir/logs log_file=$dir/logs/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" - CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --task speech_recognition --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --task speech_recognition_espresso --seed 1 --user-dir espresso \ --log-interval 1500 --log-format simple --print-training-sample-interval 2000 \ --num-workers 0 --max-tokens 26000 --max-sentences 48 --curriculum 2 \ --valid-subset $valid_subset --max-sentences-valid 64 --ddp-backend no_c10d \ @@ -281,7 +281,7 @@ if [ $stage -le 7 ]; then # only score train_dev with built-in scorer text_opt= && [ "$dataset" == "train_dev" ] && text_opt="--test-text-files data/$dataset/token_text" CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py \ - --task speech_recognition --user-dir espresso --max-tokens 24000 --max-sentences 48 \ + --task speech_recognition_espresso --user-dir espresso --max-tokens 24000 --max-sentences 48 \ --num-shards 1 --shard-id 0 --test-feat-files ${dumpdir}/$dataset/delta${do_delta}/feats.scp $text_opt \ --dict $dict --remove-bpe sentencepiece --non-lang-syms $nlsyms \ --max-source-positions 9999 --max-target-positions 999 \ diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index 363d7afcd..1b2d94ccb 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -264,7 +264,7 @@ if [ ${stage} -le 8 ]; then mkdir -p $dir/logs log_file=$dir/logs/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" - CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --task speech_recognition --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --task speech_recognition_espresso --seed 1 --user-dir espresso \ --log-interval 400 --log-format simple --print-training-sample-interval 1000 \ --num-workers 0 --max-tokens 24000 --max-sentences 32 --curriculum 2 \ --valid-subset $valid_subset --max-sentences-valid 64 --ddp-backend no_c10d \ @@ -307,7 +307,7 @@ if [ ${stage} -le 9 ]; then fi text=data/$dataset/token_text CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py \ - --task speech_recognition --user-dir espresso --max-tokens 20000 --max-sentences 32 \ + --task speech_recognition_espresso --user-dir espresso --max-tokens 20000 --max-sentences 32 \ --num-shards 1 --shard-id 0 --test-feat-files $feat --test-text-files $text \ --dict $dict --non-lang-syms $nlsyms \ --max-source-positions 9999 --max-target-positions 999 \ diff --git a/setup.py b/setup.py index 195429803..a587e15b9 100644 --- a/setup.py +++ b/setup.py @@ -186,6 +186,7 @@ def do_setup(package_data): 'dataclasses; python_version<"3.7"', "hydra-core<1.1", "omegaconf<2.1", + "kaldi_io", 
'numpy<1.20.0; python_version<"3.7"', 'numpy; python_version>="3.7"', "regex", diff --git a/tests/espresso/test_speech_dataset.py b/tests/espresso/test_speech_dataset.py index eb48acc4c..64be72531 100644 --- a/tests/espresso/test_speech_dataset.py +++ b/tests/espresso/test_speech_dataset.py @@ -11,11 +11,11 @@ import torch from espresso.data import ( + AsrDictionary, + AsrTextDataset, ScpCachedDataset, ScpInMemoryDataset, SpeechDataset, - TokenDictionary, - TokenTextDataset, ) try: @@ -29,12 +29,12 @@ class TestSpeechDataset(unittest.TestCase): @staticmethod def make_dictionary(): """construct dictionary.""" - d = TokenDictionary() + d = AsrDictionary() alphabet = string.ascii_lowercase for token in alphabet: d.add_symbol(token) d.add_symbol('') - d.finalize(padding_factor=1) # don't add extra padding symbols + d.finalize(padding_factor=1) # don't add extra padding symbols d.space_index = d.indices.get('', -1) return d @@ -43,8 +43,9 @@ def generate_feats(test_dir, num=10, seed=0): """generate feature matrices.""" feats = {} np.random.seed(seed) - with open(os.path.join(test_dir, 'feats.scp'), 'w', - encoding='utf-8') as f: + with open( + os.path.join(test_dir, 'feats.scp'), 'w', encoding='utf-8', + ) as f: for i in range(num): utt_id = 'utt_id_' + str(i) ark_file = os.path.join(test_dir, 'mat_' + str(i) + '.ark') @@ -65,13 +66,15 @@ def generate_text_tokens(test_dir, num=10, seed=0): vocab = list(alphabet) vocab.append(space) np.random.seed(seed) - with open(os.path.join(test_dir, 'text_tokens'), 'w', - encoding='utf-8') as f: + with open( + os.path.join(test_dir, 'text_tokens'), 'w', encoding='utf-8', + ) as f: for i in np.random.permutation(range(num)): utt_id = 'utt_id_' + str(i) length = np.random.randint(10, 100) - tokens = [vocab[np.random.randint(0, len(vocab))] \ - for _ in range(length)] + tokens = [ + vocab[np.random.randint(0, len(vocab))] for _ in range(length) + ] if tokens[0] == space: tokens[0] = vocab[np.random.randint(0, len(vocab) - 1)] if tokens[-1] == space: @@ -88,15 +91,18 @@ def setUp(self): self.batch_size = 8 self.cache_size = 16 self.dict = self.make_dictionary() - self.expected_feats = self.generate_feats(self.test_dir, - num=self.num_audios, seed=0) - self.expected_tokens = self.generate_text_tokens(self.test_dir, - num=self.num_transripts, seed=1) + self.expected_feats = self.generate_feats( + self.test_dir, num=self.num_audios, seed=0, + ) + self.expected_tokens = self.generate_text_tokens( + self.test_dir, num=self.num_transripts, seed=1, + ) self.cuda = torch.cuda.is_available() - def _speech_dataset_helper(self, all_in_memory=False, - ordered_prefetch=False): + def _speech_dataset_helper( + self, all_in_memory=False, ordered_prefetch=False, + ): if not all_in_memory: src_dataset = ScpCachedDataset( path=os.path.join(self.test_dir, 'feats.scp'), @@ -107,7 +113,7 @@ def _speech_dataset_helper(self, all_in_memory=False, src_dataset = ScpInMemoryDataset( path=os.path.join(self.test_dir, 'feats.scp') ) - tgt_dataset = TokenTextDataset( + tgt_dataset = AsrTextDataset( path=os.path.join(self.test_dir, 'text_tokens'), dictionary=self.dict, ) @@ -171,12 +177,15 @@ def test_speech_dataset_all_in_memory(self): def assertTensorEqual(self, t1, t2): self.assertEqual(t1.size(), t2.size(), "size mismatch") - if (t1.dtype == torch.short or t1.dtype == torch.int or \ - t1.dtype == torch.long) and (t2.dtype == torch.short or \ - t2.dtype == torch.int or t2.dtype == torch.long): + if ( + (t1.dtype == torch.short or t1.dtype == torch.int or + t1.dtype == torch.long) and + 
(t2.dtype == torch.short or t2.dtype == torch.int or + t2.dtype == torch.long) + ): self.assertEqual(t1.ne(t2).long().sum(), 0) else: - self.assertEqual(t1.allclose(t2,rtol=1e-05, atol=1e-08), True) + self.assertEqual(t1.allclose(t2, rtol=1e-05, atol=1e-08), True) if __name__ == "__main__": diff --git a/tests/espresso/test_speech_utils.py b/tests/espresso/test_speech_utils.py index 9a3c0b2a9..e8213d997 100644 --- a/tests/espresso/test_speech_utils.py +++ b/tests/espresso/test_speech_utils.py @@ -10,7 +10,7 @@ import torch -from espresso.data import TokenDictionary +from espresso.data import AsrDictionary import espresso.tools.utils as utils @@ -21,13 +21,13 @@ class TestSpeechUtils(unittest.TestCase): def make_dictionary(vocab, non_lang_syms=[]): """construct dictionary.""" assert isinstance(vocab, list) and isinstance(non_lang_syms, list) - d = TokenDictionary() + d = AsrDictionary() for token in vocab: d.add_symbol(token) d.add_symbol('') for token in non_lang_syms: d.add_symbol(token) - d.finalize(padding_factor=1) # don't add extra padding symbols + d.finalize(padding_factor=1) # don't add extra padding symbols d.space_index = d.indices.get('', -1) return d @@ -60,32 +60,37 @@ def setUp(self): self.oovs = list(string.ascii_uppercase) self.non_lang_syms = ['', '', ''] self.num_sentences = 100 - self.dict = self.make_dictionary(self.vocab, + self.dict = self.make_dictionary( + self.vocab, non_lang_syms=self.non_lang_syms, ) - self.text = [self.generate_text(self.vocab, self.oovs, - self.non_lang_syms, seed=i) for i in range(self.num_sentences)] + self.text = [self.generate_text( + self.vocab, self.oovs, self.non_lang_syms, seed=i, + ) for i in range(self.num_sentences)] def test_speech_tokenizer(self): for i, sent in enumerate(self.text): print('test sentence {}:'.format(i)) print(sent) - tokens = utils.tokenize(sent, \ - space=self.dict.space_word, non_lang_syms=self.non_lang_syms) + tokens = utils.tokenize( + sent, space=self.dict.space_word, + non_lang_syms=self.non_lang_syms, + ) # test :func:`~speech_tools.utils.tokenize` with - # :func:`~TokenDictionary.encode_line` - tensor = self.dict.encode_line(tokens, add_if_not_exist=False, - append_eos=True) + # :func:`~AsrDictionary.encode_line` + tensor = self.dict.encode_line( + tokens, add_if_not_exist=False, append_eos=True, + ) reconstructed_tokens = self.dict.string(tensor) expected_tokens = ' '.join( - [token if self.dict.index(token) != self.dict.unk() else \ + [token if self.dict.index(token) != self.dict.unk() else self.dict.unk_word for token in tokens.split(' ')] ) self.assertEqual(reconstructed_tokens, expected_tokens) # test :func:`~speech_tools.utils.tokenize` with - # :func:`~TokenDictionary.tokens_to_sentence` + # :func:`~AsrDictionary.tokens_to_sentence` reconstructed_sent = self.dict.tokens_to_sentence(tokens) expected_sent = [] words = sent.split(' ') @@ -129,12 +134,12 @@ def test_sequence_mask(self): [1, 0, 0, 0], [1, 1, 1, 1], [0, 0, 0, 0], - [1, 1, 1, 0]]).byte() + [1, 1, 1, 0]]).bool() expected_mask2 = torch.tensor([ [1, 0, 0, 0, 0], [1, 1, 1, 1, 0], [0, 0, 0, 0, 0], - [1, 1, 1, 0, 0]]).byte() + [1, 1, 1, 0, 0]]).bool() generated_mask = utils.sequence_mask(seq_len) generated_mask2 = utils.sequence_mask(seq_len, max_len=5) @@ -155,56 +160,72 @@ def test_convert_padding_direction(self): [0.0, 0.0, 0.0, 1.5]]).unsqueeze(-1).expand(-1, -1, 10) seq_len = torch.tensor([3, 2, 4, 1]).int() - t1_to_t2 = utils.convert_padding_direction(t1, seq_len, - right_to_left=True) + t1_to_t2 = utils.convert_padding_direction( + t1, 
seq_len, right_to_left=True, + ) self.assertTensorEqual(t1_to_t2, t2) - t2_to_t1 = utils.convert_padding_direction(t2, seq_len, - left_to_right=True) + t2_to_t1 = utils.convert_padding_direction( + t2, seq_len, left_to_right=True, + ) self.assertTensorEqual(t2_to_t1, t1) def test_edit_distance(self): ref, hyp = [], [] dist, steps, counter = utils.edit_distance(ref, hyp) - self.assertEqual(counter, - Counter({'words': 0, 'corr': 0, 'sub': 0, 'ins': 0, 'del': 0})) + self.assertEqual( + counter, + Counter({'words': 0, 'corr': 0, 'sub': 0, 'ins': 0, 'del': 0}), + ) self.assertEqual(steps, []) ref, hyp = ['a', 'b', 'c'], [] dist, steps, counter = utils.edit_distance(ref, hyp) - self.assertEqual(counter, - Counter({'words': 3, 'corr': 0, 'sub': 0, 'ins': 0, 'del': 3})) + self.assertEqual( + counter, + Counter({'words': 3, 'corr': 0, 'sub': 0, 'ins': 0, 'del': 3}), + ) self.assertEqual(steps, ['del', 'del', 'del']) ref, hyp = ['a', 'b', 'c'], ['a', 'b', 'c'] dist, steps, counter = utils.edit_distance(ref, hyp) - self.assertEqual(counter, - Counter({'words': 3, 'corr': 3, 'sub': 0, 'ins': 0, 'del': 0})) + self.assertEqual( + counter, + Counter({'words': 3, 'corr': 3, 'sub': 0, 'ins': 0, 'del': 0}), + ) self.assertEqual(steps, ['corr', 'corr', 'corr']) ref, hyp = ['a', 'b', 'c'], ['d', 'b', 'c', 'e', 'f'] dist, steps, counter = utils.edit_distance(ref, hyp) - self.assertEqual(counter, - Counter({'words': 3, 'corr': 2, 'sub': 1, 'ins': 2, 'del': 0})) + self.assertEqual( + counter, + Counter({'words': 3, 'corr': 2, 'sub': 1, 'ins': 2, 'del': 0}), + ) self.assertEqual(steps, ['sub', 'corr', 'corr', 'ins', 'ins']) ref, hyp = ['b', 'c', 'd', 'e', 'f', 'h'], \ ['d', 'b', 'c', 'e', 'f', 'g'] dist, steps, counter = utils.edit_distance(ref, hyp) - self.assertEqual(counter, - Counter({'words': 6, 'corr': 4, 'sub': 1, 'ins': 1, 'del': 1})) - self.assertEqual(steps, - ['ins', 'corr', 'corr', 'del', 'corr', 'corr', 'sub']) + self.assertEqual( + counter, + Counter({'words': 6, 'corr': 4, 'sub': 1, 'ins': 1, 'del': 1}), + ) + self.assertEqual( + steps, + ['ins', 'corr', 'corr', 'del', 'corr', 'corr', 'sub'], + ) def assertTensorEqual(self, t1, t2): self.assertEqual(t1.size(), t2.size(), "size mismatch") - if (t1.dtype == torch.short or t1.dtype == torch.int or \ - t1.dtype == torch.long or t1.dtype == torch.uint8) and \ - (t2.dtype == torch.short or t2.dtype == torch.int or \ - t2.dtype == torch.long or t2.dtype == torch.uint8): + if (t1.dtype == torch.short or t1.dtype == torch.int or + t1.dtype == torch.long or t1.dtype == torch.uint8 or + t1.dtype == torch.bool) and \ + (t2.dtype == torch.short or t2.dtype == torch.int or + t2.dtype == torch.long or t2.dtype == torch.uint8 or + t2.dtype == torch.bool): self.assertEqual(t1.ne(t2).long().sum(), 0) else: - self.assertEqual(t1.allclose(t2,rtol=1e-05, atol=1e-08), True) + self.assertEqual(t1.allclose(t2, rtol=1e-05, atol=1e-08), True) if __name__ == "__main__": From 732ab08018937a2371e0b764b10d7e8dff81d5ac Mon Sep 17 00:00:00 2001 From: Tongfei Chen Date: Fri, 20 Dec 2019 18:43:14 -0500 Subject: [PATCH 053/119] scheduled sampling rate scheduler --- .../scheduled_sampling_rate_scheduler.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 espresso/tools/scheduled_sampling_rate_scheduler.py diff --git a/espresso/tools/scheduled_sampling_rate_scheduler.py b/espresso/tools/scheduled_sampling_rate_scheduler.py new file mode 100644 index 000000000..4eca7ebe8 --- /dev/null +++ b/espresso/tools/scheduled_sampling_rate_scheduler.py @@ -0,0 +1,39 
@@ +# Copyright (c) Tongfei Chen, Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse + +from fairseq.options import eval_str_list + + +class ScheduledSamplingRateScheduler: + + def __init__(self, args): + self.args = args + + @staticmethod + def add_args(parser: argparse.ArgumentParser): + parser.add_argument('--scheduled-sampling-probs', type=lambda p: eval_str_list(p), + metavar='P_1,P_2,...,P_N', default=1.0, + help='scheduled sampling probabilities of sampling the truth ' + 'labels for N epochs starting from --start-schedule-sampling-epoch; ' + 'all later epochs using P_N') + parser.add_argument('--start-scheduled-sampling-epoch', type=int, + metavar='N', default=1, + help='start scheduled sampling from the specified epoch') + + def step(self, epoch: int) -> float: + if ( + (len(self.args.scheduled_sampling_probs) > 1 or + self.args.scheduled_sampling_probs[0] < 1.0) and + epoch >= self.args.start_scheduled_sampling_epoch + ): + ss_prob = self.args.scheduled_sampling_probs[ + min(epoch - self.args.start_scheduled_sampling_epoch, + len(self.args.scheduled_sampling_probs) - 1) + ] + return ss_prob + else: + return 1.0 From ad0157ec1a6eea360fbfae37be5d05a52378b4bd Mon Sep 17 00:00:00 2001 From: freewym Date: Fri, 20 Dec 2019 21:47:56 -0500 Subject: [PATCH 054/119] decouple scheduled sampling rate scheduler; rename all appearances of "dict" variables to "dictionary" as "dict" is a reserved keyword in Python; affect swbd results due to some numerical issue of PyTorch --- espresso/criterions/cross_entropy_with_wer.py | 81 ++++-------------- .../label_smoothed_cross_entropy_with_wer.py | 85 ++++--------------- espresso/data/speech_dataset.py | 8 +- espresso/models/speech_lstm.py | 59 ++++++++++++- espresso/speech_recognize.py | 12 +-- espresso/tasks/speech_recognition.py | 31 ++++--- .../scheduled_sampling_rate_scheduler.py | 48 ++++++----- espresso/tools/wer.py | 12 +-- examples/asr_swbd/run.sh | 2 +- tests/espresso/test_speech_dataset.py | 8 +- tests/espresso/test_speech_utils.py | 16 ++-- 11 files changed, 161 insertions(+), 201 deletions(-) diff --git a/espresso/criterions/cross_entropy_with_wer.py b/espresso/criterions/cross_entropy_with_wer.py index 91a4d7b45..78889e3d2 100644 --- a/espresso/criterions/cross_entropy_with_wer.py +++ b/espresso/criterions/cross_entropy_with_wer.py @@ -10,7 +10,6 @@ from fairseq import utils from fairseq.data import data_utils from fairseq.models import FairseqIncrementalDecoder -from fairseq.options import eval_str_list from fairseq.criterions import register_criterion from fairseq.criterions.cross_entropy import CrossEntropyCriterion @@ -24,8 +23,8 @@ class CrossEntropyWithWERCriterion(CrossEntropyCriterion): def __init__(self, args, task): super().__init__(args, task) - dict = task.target_dictionary - self.scorer = wer.Scorer(dict, wer_output_filter=task.args.wer_output_filter) + dictionary = task.target_dictionary + self.scorer = wer.Scorer(dictionary, wer_output_filter=task.args.wer_output_filter) self.train_tgt_dataset = None self.valid_tgt_dataset = None self.num_updates = -1 @@ -39,14 +38,6 @@ def add_args(parser): metavar='N', dest='print_interval', default=500, help='print a training sample (reference + ' 'prediction) every this number of updates') - parser.add_argument('--scheduled-sampling-probs', type=lambda p: eval_str_list(p), - metavar='P_1,P_2,...,P_N', default=1.0, - help='schedule sampling probabilities of sampling the truth ' - 
'labels for N epochs starting from --start-schedule-sampling-epoch; ' - 'all later epochs using P_N') - parser.add_argument('--start-scheduled-sampling-epoch', type=int, - metavar='N', default=1, - help='start schedule sampling from the specified epoch') # fmt: on def forward(self, model, sample, reduce=True): @@ -59,51 +50,11 @@ def forward(self, model, sample, reduce=True): 2) the sample size, which is used as the denominator for the gradient 3) logging outputs to display while training """ - dict = self.scorer.dict + dictionary = self.scorer.dictionary if model.training: - if ( - (len(self.args.scheduled_sampling_probs) > 1 or - self.args.scheduled_sampling_probs[0] < 1.0) and - self.epoch >= self.args.start_scheduled_sampling_epoch - ): - # scheduled sampling - ss_prob = self.args.scheduled_sampling_probs[ - min(self.epoch - self.args.start_scheduled_sampling_epoch, - len(self.args.scheduled_sampling_probs) - 1) - ] - assert isinstance(model.decoder, FairseqIncrementalDecoder) - incremental_states = {} - encoder_input = { - k: v for k, v in sample['net_input'].items() - if k != 'prev_output_tokens' - } - encoder_out = model.encoder(**encoder_input) - target = sample['target'] - tokens = sample['net_input']['prev_output_tokens'] - lprobs = [] - pred = None - for step in range(target.size(1)): - if step > 0: - sampling_mask = torch.rand( - [target.size(0), 1], - device=target.device, - ).lt(ss_prob) - feed_tokens = torch.where( - sampling_mask, tokens[:, step:step + 1], pred, - ) - else: - feed_tokens = tokens[:, step:step + 1] - log_probs, _ = self._decode( - feed_tokens, model, encoder_out, incremental_states, - ) - pred = log_probs.argmax(-1, keepdim=True) - lprobs.append(log_probs) - lprobs = torch.stack(lprobs, dim=1) - else: - # normal training - net_output = model(**sample['net_input']) - lprobs = model.get_normalized_probs(net_output, log_probs=True) - target = model.get_targets(sample, net_output) + net_output = model(**sample['net_input'], epoch=self.epoch) + lprobs = model.get_normalized_probs(net_output, log_probs=True) + target = model.get_targets(sample, net_output) else: assert isinstance(model.decoder, FairseqIncrementalDecoder) incremental_states = {} @@ -117,13 +68,13 @@ def forward(self, model, sample, reduce=True): # target, and the length of encoder_out if possible maxlen = max(encoder_out['encoder_out'][0].size(1), target.size(1)) tokens = target.new_full([target.size(0), maxlen + 2], self.padding_idx) - tokens[:, 0] = dict.eos() + tokens[:, 0] = dictionary.eos() lprobs = [] attn = [] if getattr(model.decoder, 'need_attn', False) else None dummy_log_probs = encoder_out['encoder_out'][0].new_full( - [target.size(0), len(dict)], -np.log(len(dict))) + [target.size(0), len(dictionary)], -np.log(len(dictionary))) for step in range(maxlen + 1): # one extra step for EOS marker - is_eos = tokens[:, step].eq(dict.eos()) + is_eos = tokens[:, step].eq(dictionary.eos()) # if all predictions are finished (i.e., ended with eos), # pad lprobs to target length with dummy log probs, # truncate tokens up to this step and break @@ -140,7 +91,7 @@ def forward(self, model, sample, reduce=True): # make log_probs uniform if the previous output token is EOS # and add consecutive EOS to the end of prediction log_probs[is_eos, :] = -np.log(log_probs.size(1)) - tokens[is_eos, step + 1] = dict.eos() + tokens[is_eos, step + 1] = dictionary.eos() if step < target.size(1): lprobs.append(log_probs) if getattr(model.decoder, 'need_attn', False): @@ -167,13 +118,13 @@ def forward(self, model, 
sample, reduce=True): for i in range(target.size(0)): utt_id = sample['utt_id'][i] id = sample['id'].data[i].item() - # ref_tokens = dict.string(target.data[i]) + # ref_tokens = dictionary.string(target.data[i]) # if it is a dummy batch (e.g., a "padding" batch in a sharded # dataset), id might exceeds the dataset size; in that case we # just skip it if id < len(self.valid_tgt_dataset): ref_tokens = self.valid_tgt_dataset.get_original_tokens(id) - pred_tokens = dict.string(pred.data[i]) + pred_tokens = dictionary.string(pred.data[i]) self.scorer.add_evaluation( utt_id, ref_tokens, pred_tokens, bpe_symbol=self.args.remove_bpe, @@ -184,12 +135,12 @@ def forward(self, model, sample, reduce=True): i = np.random.randint(0, len(sample['id'])) id = sample['id'].data[i].item() length = utils.strip_pad(target.data[i], self.padding_idx).size(0) - # ref_one = dict.tokens_to_sentence(dict.string(target.data[i])) + # ref_one = dictionary.tokens_to_sentence(dictionary.string(target.data[i])) ref_one = self.train_tgt_dataset.get_original_text( - id, dict, bpe_symbol=self.args.remove_bpe, + id, dictionary, bpe_symbol=self.args.remove_bpe, ) - pred_one = dict.tokens_to_sentence( - dict.string(pred.data[i][:length]), + pred_one = dictionary.tokens_to_sentence( + dictionary.string(pred.data[i][:length]), bpe_symbol=self.args.remove_bpe, ) print('| sample REF: ' + ref_one) diff --git a/espresso/criterions/label_smoothed_cross_entropy_with_wer.py b/espresso/criterions/label_smoothed_cross_entropy_with_wer.py index 1fdc44785..ed87bfdf0 100644 --- a/espresso/criterions/label_smoothed_cross_entropy_with_wer.py +++ b/espresso/criterions/label_smoothed_cross_entropy_with_wer.py @@ -9,7 +9,6 @@ from fairseq import utils from fairseq.data import data_utils from fairseq.models import FairseqIncrementalDecoder -from fairseq.options import eval_str_list from fairseq.criterions import register_criterion from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion @@ -55,17 +54,17 @@ class LabelSmoothedCrossEntropyWithWERCriterion(LabelSmoothedCrossEntropyCriteri def __init__(self, args, task): super().__init__(args, task) - dict = task.target_dictionary - self.scorer = wer.Scorer(dict, wer_output_filter=task.args.wer_output_filter) + dictionary = task.target_dictionary + self.scorer = wer.Scorer(dictionary, wer_output_filter=task.args.wer_output_filter) self.train_tgt_dataset = None self.valid_tgt_dataset = None self.num_updates = -1 self.epoch = 0 self.unigram_tensor = None if args.smoothing_type == 'unigram': - self.unigram_tensor = torch.cuda.FloatTensor(dict.count).unsqueeze(-1) \ + self.unigram_tensor = torch.cuda.FloatTensor(dictionary.count).unsqueeze(-1) \ if torch.cuda.is_available() and not args.cpu \ - else torch.FloatTensor(dict.count).unsqueeze(-1) + else torch.FloatTensor(dictionary.count).unsqueeze(-1) self.unigram_tensor += args.unigram_pseudo_count # for further backoff self.unigram_tensor.div_(self.unigram_tensor.sum()) @@ -84,14 +83,6 @@ def add_args(parser): parser.add_argument('--unigram-pseudo-count', type=float, default=1.0, metavar='C', help='pseudo count for unigram label ' 'smoothing. 
Only relevant if --smoothing-type=unigram') - parser.add_argument('--scheduled-sampling-probs', type=lambda p: eval_str_list(p), - metavar='P_1,P_2,...,P_N', default=1.0, - help='scheduled sampling probabilities of sampling the truth ' - 'labels for N epochs starting from --start-schedule-sampling-epoch; ' - 'all later epochs using P_N') - parser.add_argument('--start-scheduled-sampling-epoch', type=int, - metavar='N', default=1, - help='start scheduled sampling from the specified epoch') # fmt: on def forward(self, model, sample, reduce=True): @@ -104,51 +95,11 @@ def forward(self, model, sample, reduce=True): 2) the sample size, which is used as the denominator for the gradient 3) logging outputs to display while training """ - dict = self.scorer.dict + dictionary = self.scorer.dictionary if model.training: - if ( - (len(self.args.scheduled_sampling_probs) > 1 or - self.args.scheduled_sampling_probs[0] < 1.0) and - self.epoch >= self.args.start_scheduled_sampling_epoch - ): - # scheduled sampling - ss_prob = self.args.scheduled_sampling_probs[ - min(self.epoch - self.args.start_scheduled_sampling_epoch, - len(self.args.scheduled_sampling_probs) - 1) - ] - assert isinstance(model.decoder, FairseqIncrementalDecoder) - incremental_states = {} - encoder_input = { - k: v for k, v in sample['net_input'].items() - if k != 'prev_output_tokens' - } - encoder_out = model.encoder(**encoder_input) - target = sample['target'] - tokens = sample['net_input']['prev_output_tokens'] - lprobs = [] - pred = None - for step in range(target.size(1)): - if step > 0: - sampling_mask = torch.rand( - [target.size(0), 1], - device=target.device, - ).lt(ss_prob) - feed_tokens = torch.where( - sampling_mask, tokens[:, step:step + 1], pred, - ) - else: - feed_tokens = tokens[:, step:step + 1] - log_probs, _ = self._decode( - feed_tokens, model, encoder_out, incremental_states, - ) - pred = log_probs.argmax(-1, keepdim=True) - lprobs.append(log_probs) - lprobs = torch.stack(lprobs, dim=1) - else: - # normal training - net_output = model(**sample['net_input']) - lprobs = model.get_normalized_probs(net_output, log_probs=True) - target = model.get_targets(sample, net_output) + net_output = model(**sample['net_input'], epoch=self.epoch) + lprobs = model.get_normalized_probs(net_output, log_probs=True) + target = model.get_targets(sample, net_output) else: assert isinstance(model.decoder, FairseqIncrementalDecoder) incremental_states = {} @@ -162,13 +113,13 @@ def forward(self, model, sample, reduce=True): # target, and the length of encoder_out if possible maxlen = max(encoder_out['encoder_out'][0].size(1), target.size(1)) tokens = target.new_full([target.size(0), maxlen + 2], self.padding_idx) - tokens[:, 0] = dict.eos() + tokens[:, 0] = dictionary.eos() lprobs = [] attn = [] if getattr(model.decoder, 'need_attn', False) else None dummy_log_probs = encoder_out['encoder_out'][0].new_full( - [target.size(0), len(dict)], -np.log(len(dict))) + [target.size(0), len(dictionary)], -np.log(len(dictionary))) for step in range(maxlen + 1): # one extra step for EOS marker - is_eos = tokens[:, step].eq(dict.eos()) + is_eos = tokens[:, step].eq(dictionary.eos()) # if all predictions are finished (i.e., ended with eos), # pad lprobs to target length with dummy log probs, # truncate tokens up to this step and break @@ -185,7 +136,7 @@ def forward(self, model, sample, reduce=True): # make log_probs uniform if the previous output token is EOS # and add consecutive EOS to the end of prediction log_probs[is_eos, :] = 
-np.log(log_probs.size(1)) - tokens[is_eos, step + 1] = dict.eos() + tokens[is_eos, step + 1] = dictionary.eos() if step < target.size(1): lprobs.append(log_probs) if getattr(model.decoder, 'need_attn', False): @@ -212,13 +163,13 @@ def forward(self, model, sample, reduce=True): for i in range(target.size(0)): utt_id = sample['utt_id'][i] id = sample['id'].data[i].item() - # ref_tokens = dict.string(target.data[i]) + # ref_tokens = dictionary.string(target.data[i]) # if it is a dummy batch (e.g., a "padding" batch in a sharded # dataset), id might exceeds the dataset size; in that case we # just skip it if id < len(self.valid_tgt_dataset): ref_tokens = self.valid_tgt_dataset.get_original_tokens(id) - pred_tokens = dict.string(pred.data[i]) + pred_tokens = dictionary.string(pred.data[i]) self.scorer.add_evaluation( utt_id, ref_tokens, pred_tokens, bpe_symbol=self.args.remove_bpe, @@ -229,12 +180,12 @@ def forward(self, model, sample, reduce=True): i = np.random.randint(0, len(sample['id'])) id = sample['id'].data[i].item() length = utils.strip_pad(target.data[i], self.padding_idx).size(0) - # ref_one = dict.tokens_to_sentence(dict.string(target.data[i])) + # ref_one = dictionary.tokens_to_sentence(dictionary.string(target.data[i])) ref_one = self.train_tgt_dataset.get_original_text( - id, dict, bpe_symbol=self.args.remove_bpe, + id, dictionary, bpe_symbol=self.args.remove_bpe, ) - pred_one = dict.tokens_to_sentence( - dict.string(pred.data[i][:length]), + pred_one = dictionary.tokens_to_sentence( + dictionary.string(pred.data[i][:length]), bpe_symbol=self.args.remove_bpe, ) print('| sample REF: ' + ref_one) diff --git a/espresso/data/speech_dataset.py b/espresso/data/speech_dataset.py index 1b6b1de9c..8bd6e82c3 100644 --- a/espresso/data/speech_dataset.py +++ b/espresso/data/speech_dataset.py @@ -85,7 +85,7 @@ class SpeechDataset(FairseqDataset): src_sizes (List[int]): source sentence lengths tgt (torch.utils.data.Dataset, optional): target dataset to wrap tgt_sizes (List[int], optional): target sentence lengths - dict (~fairseq.data.Dictionary, optional): target vocabulary + dictionary (~fairseq.data.Dictionary, optional): target vocabulary left_pad_source (bool, optional): pad source tensors on the left side (default: True). left_pad_target (bool, optional): pad target tensors on the left side @@ -102,7 +102,7 @@ class SpeechDataset(FairseqDataset): def __init__( self, src, src_sizes, - tgt=None, tgt_sizes=None, dict=None, + tgt=None, tgt_sizes=None, dictionary=None, left_pad_source=False, left_pad_target=False, max_source_positions=1024, max_target_positions=1024, shuffle=True, input_feeding=True, @@ -111,7 +111,7 @@ def __init__( self.tgt = tgt self.src_sizes = np.array(src_sizes) self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None - self.dict = dict + self.dictionary = dictionary self.left_pad_source = left_pad_source self.left_pad_target = left_pad_target self.max_source_positions = max_source_positions @@ -187,7 +187,7 @@ def collater(self, samples): on the left if *left_pad_target* is ``True``. 
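# Shape sketch for the encoder side of the batch produced by collater() above: variable-length
# feature matrices (T_i x C) are padded into one B x T_max x C tensor while src_lengths keeps
# the true lengths. This is a plain-torch illustration, not the real collate code.
import torch

feats = [torch.randn(5, 40), torch.randn(3, 40)]          # two utterances, 40-dim features
src_lengths = torch.tensor([f.size(0) for f in feats])    # tensor([5, 3])
src_tokens = feats[0].new_zeros(len(feats), int(src_lengths.max()), feats[0].size(1))
for i, f in enumerate(feats):
    src_tokens[i, :f.size(0)] = f                          # right-padding, as in collate_frames
print(src_tokens.shape, src_lengths)                       # torch.Size([2, 5, 40]) tensor([5, 3])
# the target side is collated like ordinary fairseq token batches
# (prev_output_tokens / target, padded with dictionary.pad())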
""" return collate( - samples, pad_idx=self.dict.pad(), eos_idx=self.dict.eos(), + samples, pad_idx=self.dictionary.pad(), eos_idx=self.dictionary.eos(), left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target, input_feeding=self.input_feeding, ) diff --git a/espresso/models/speech_lstm.py b/espresso/models/speech_lstm.py index ecb5400ec..1168f6945 100644 --- a/espresso/models/speech_lstm.py +++ b/espresso/models/speech_lstm.py @@ -27,6 +27,7 @@ from espresso.modules import speech_attention from espresso.tasks.speech_recognition import SpeechRecognitionEspressoTask +from espresso.tools.scheduled_sampling_rate_scheduler import ScheduledSamplingRateScheduler import espresso.tools.utils as speech_utils @@ -105,6 +106,16 @@ def add_args(parser): help='dropout probability for decoder input embedding') parser.add_argument('--decoder-dropout-out', type=float, metavar='D', help='dropout probability for decoder output') + + # Scheduled sampling options + parser.add_argument('--scheduled-sampling-probs', type=lambda p: options.eval_str_list(p), + metavar='P_1,P_2,...,P_N', default=1.0, + help='scheduled sampling probabilities of sampling the truth ' + 'labels for N epochs starting from --start-schedule-sampling-epoch; ' + 'all later epochs using P_N') + parser.add_argument('--start-scheduled-sampling-epoch', type=int, + metavar='N', default=1, + help='start scheduled sampling from the specified epoch') # fmt: on @classmethod @@ -178,6 +189,10 @@ def eval_str_nested_list_or_tuple(x, type=int): rnn_encoder_input_size = (rnn_encoder_input_size + s - 1) // s rnn_encoder_input_size *= out_channels[-1] + scheduled_sampling_rate_scheduler = ScheduledSamplingRateScheduler( + args.scheduled_sampling_probs, args.start_scheduled_sampling_epoch, + ) + encoder = SpeechLSTMEncoder( conv_layers_before=conv_layers, input_size=rnn_encoder_input_size, @@ -207,6 +222,7 @@ def eval_str_nested_list_or_tuple(x, type=int): options.eval_str_list(args.adaptive_softmax_cutoff, type=int) if args.criterion == 'adaptive_loss' else None ), + scheduled_sampling_rate_scheduler=scheduled_sampling_rate_scheduler, ) pretrained_lm = None if args.pretrained_lm_checkpoint: @@ -423,7 +439,7 @@ def output_lengths(self, in_lengths): return in_lengths if self.conv_layers_before is None \ else self.conv_layers_before.output_lengths(in_lengths) - def forward(self, src_tokens, src_lengths): + def forward(self, src_tokens, src_lengths, **unused): if self.left_pad: # nn.utils.rnn.pack_padded_sequence requires right-padding; # convert left-padding to right-padding @@ -494,6 +510,7 @@ def __init__( num_layers=1, dropout_in=0.1, dropout_out=0.1, encoder_output_units=0, attn_type=None, attn_dim=0, need_attn=False, residual=False, pretrained_embed=None, share_input_output_embed=False, adaptive_softmax_cutoff=None, + scheduled_sampling_rate_scheduler=None, ): super().__init__(dictionary) self.dropout_in = dropout_in @@ -545,7 +562,9 @@ def __init__( elif not self.share_input_output_embed: self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out) - def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused): + self.scheduled_sampling_rate_scheduler = scheduled_sampling_rate_scheduler + + def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs): """ Args: prev_output_tokens (LongTensor): previous decoder outputs of shape @@ -560,11 +579,45 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, - the decoder's output of shape 
`(batch, tgt_len, vocab)` - attention weights of shape `(batch, tgt_len, src_len)` """ + if self.scheduled_sampling_rate_scheduler is not None: + epoch = kwargs.get('epoch', 0) + if epoch > 0: + sampling_prob = self.scheduled_sampling_rate_scheduler.step(epoch) + if sampling_prob < 1.0: # apply scheduled sampling + return self._forward_with_schduled_sampling( + prev_output_tokens, sampling_prob, encoder_out=encoder_out, + incremental_state={}, # use empty dict to preserve forward state + ) + x, attn_scores = self.extract_features( - prev_output_tokens, encoder_out, incremental_state + prev_output_tokens, encoder_out, incremental_state, ) return self.output_layer(x), attn_scores + def _forward_with_schduled_sampling( + self, prev_output_tokens, sampling_prob, encoder_out=None, incremental_state=None, + ): + bsz, seqlen = prev_output_tokens.size() + outs = [] + pred = None + for step in range(seqlen): + if step > 0: + sampling_mask = torch.rand( + [bsz, 1], device=prev_output_tokens.device, + ).lt(sampling_prob) + feed_tokens = torch.where( + sampling_mask, prev_output_tokens[:, step:step + 1], pred, + ) + else: + feed_tokens = prev_output_tokens[:, step:step + 1] # B x 1 + x, _ = self.extract_features(feed_tokens, encoder_out, incremental_state) + x = self.output_layer(x) # B x 1 x V + outs.append(x) + pred = x.argmax(-1) # B x 1 + x = torch.cat(outs, dim=1) # B x T x V + # ignore attention scores + return x, None + def extract_features( self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused, ): diff --git a/espresso/speech_recognize.py b/espresso/speech_recognize.py index 2156d19c3..d7cc874e6 100755 --- a/espresso/speech_recognize.py +++ b/espresso/speech_recognize.py @@ -40,7 +40,7 @@ def main(args): task.load_dataset(args.gen_subset) # Set dictionary - dict = task.target_dictionary + dictionary = task.target_dictionary # Load ensemble print('| loading model(s) from {}'.format(args.path)) @@ -63,7 +63,7 @@ def main(args): print('| LM fusion with Multi-level LM') else: models[i] = TensorizedLookaheadLanguageModel( - m, dict, + m, dictionary, oov_penalty=args.oov_penalty, open_vocab=not args.disable_open_vocab, ) @@ -110,7 +110,7 @@ def main(args): generator = task.build_generator(args) # Generate and compute WER - scorer = wer.Scorer(dict, wer_output_filter=args.wer_output_filter) + scorer = wer.Scorer(dictionary, wer_output_filter=args.wer_output_filter) num_sentences = 0 has_target = True with progress_bar.build_progress_bar(args, itr) as t: @@ -146,16 +146,16 @@ def main(args): if has_target: target_str = task.dataset(args.gen_subset).tgt.get_original_tokens(sample_id) if not args.quiet: - target_sent = dict.tokens_to_sentence( + target_sent = dictionary.tokens_to_sentence( target_str, use_unk_sym=False, bpe_symbol=args.remove_bpe, ) print('T-{}\t{}'.format(utt_id, target_sent)) # Process top predictions for j, hypo in enumerate(hypos[i][:args.nbest]): - hypo_str = dict.string(hypo['tokens'].int().cpu()) # not removing bpe at this point + hypo_str = dictionary.string(hypo['tokens'].int().cpu()) # not removing bpe at this point if not args.quiet or i == 0: - hypo_sent = dict.tokens_to_sentence(hypo_str, bpe_symbol=args.remove_bpe) + hypo_sent = dictionary.tokens_to_sentence(hypo_str, bpe_symbol=args.remove_bpe) if not args.quiet: print('H-{}\t{}\t{}'.format(utt_id, hypo_sent, hypo['score'])) diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index 7258b9844..a23967ebf 100644 --- a/espresso/tasks/speech_recognition.py +++ 
b/espresso/tasks/speech_recognition.py @@ -26,12 +26,15 @@ class SpeechRecognitionEspressoTask(FairseqTask): Transcribe from speech (source) to token text (target). Args: - dict (~fairseq.data.AsrDictionary): dictionary for the output tokens + dictionary (~fairseq.data.AsrDictionary): dictionary for the output tokens + word_dict (~fairseq.data.AsrDictionary): dictionary for the words + (for decoding with word-based LMs) + feat_in_channels (int): input feature channels .. note:: The speech recognition task is compatible with :mod:`speech-train`, - :mod:`speech-recognition` and :mod:`fairseq-interactive`. + :mod:`speech-recognize` and :mod:`fairseq-interactive`. The speech recognition task provides the following additional command-line arguments: @@ -105,9 +108,9 @@ def build_dictionary(cls, filenames, workers=1, threshold=-1, nwords=-1, padding """ raise NotImplementedError - def __init__(self, args, dict, word_dict=None): + def __init__(self, args, dictionary, word_dict=None): super().__init__(args) - self.dict = dict + self.dictionary = dictionary self.word_dict = word_dict self.feat_in_channels = args.feat_in_channels torch.backends.cudnn.deterministic = True @@ -130,15 +133,15 @@ def setup_task(cls, args, **kwargs): dict_path = os.path.join(os.path.dirname(args.text_files[0]), 'dict.txt') \ if args.dict is None and args.text_files is not None else args.dict assert dict_path is not None, 'Please specify --dict' - dict = cls.load_dictionary(dict_path, non_lang_syms=args.non_lang_syms) - print('| dictionary: {} types'.format(len(dict))) + dictionary = cls.load_dictionary(dict_path, non_lang_syms=args.non_lang_syms) + print('| dictionary: {} types'.format(len(dictionary))) if args.word_dict is not None: word_dict = cls.load_dictionary(args.word_dict) print('| word dictionary: {} types'.format(len(word_dict))) - return cls(args, dict, word_dict) + return cls(args, dictionary, word_dict) else: - return cls(args, dict) + return cls(args, dictionary) def load_dataset(self, split, epoch=0, combine=False, **kwargs): """Load a given dataset split. 
@@ -177,7 +180,7 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs): src_datasets.append(ScpCachedDataset(feat, ordered_prefetch=True)) print('| {} {} examples'.format(feat, len(src_datasets[-1]))) if text is not None: - tgt_datasets.append(AsrTextDataset(text, self.dict)) + tgt_datasets.append(AsrTextDataset(text, self.dictionary)) print('| {} {} examples'.format(text, len(tgt_datasets[-1]))) if not combine: @@ -204,7 +207,7 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs): self.datasets[split] = SpeechDataset( src_dataset, src_dataset.sizes, tgt_dataset, tgt_dataset.sizes if tgt_dataset is not None else None, - self.dict, + self.dictionary, left_pad_source=self.args.left_pad_source, left_pad_target=self.args.left_pad_target, max_source_positions=self.args.max_source_positions, @@ -213,11 +216,11 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs): # update the counts of and in dictionary with training data if split == 'train': - self.dict.count[self.dict.eos()] = len(tgt_dataset) + self.dictionary.count[self.dictionary.eos()] = len(tgt_dataset) unk_count = 0 for i in range(len(tgt_dataset)): - unk_count += (tgt_dataset[i] == self.dict.unk()).int().sum().item() - self.dict.count[self.dict.unk()] = unk_count + unk_count += (tgt_dataset[i] == self.dictionary.unk()).int().sum().item() + self.dictionary.count[self.dictionary.unk()] = unk_count def build_generator(self, args): if args.score_reference: @@ -261,7 +264,7 @@ def max_positions(self): @property def target_dictionary(self): """Return the target :class:`~fairseq.data.AsrDictionary`.""" - return self.dict + return self.dictionary @property def word_dictionary(self): diff --git a/espresso/tools/scheduled_sampling_rate_scheduler.py b/espresso/tools/scheduled_sampling_rate_scheduler.py index 4eca7ebe8..2b9b029d0 100644 --- a/espresso/tools/scheduled_sampling_rate_scheduler.py +++ b/espresso/tools/scheduled_sampling_rate_scheduler.py @@ -3,37 +3,39 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import argparse +from typing import List -from fairseq.options import eval_str_list +class ScheduledSamplingRateScheduler(object): -class ScheduledSamplingRateScheduler: + def __init__( + self, + scheduled_sampling_probs: List[float] = [1.0], + start_scheduled_sampling_epoch: int = 1, + ): + """ + Args: + scheduled_sampling_probs (List[float]): P_1,P_2,...,P_N. + Scheduled sampling probabilities of sampling the truth labels + for N epochs starting from --start-schedule-sampling-epoch; + all later epochs using P_N. + start_scheduled_sampling_epoch (int): start scheduled sampling from + the specified epoch. 
+ """ - def __init__(self, args): - self.args = args - - @staticmethod - def add_args(parser: argparse.ArgumentParser): - parser.add_argument('--scheduled-sampling-probs', type=lambda p: eval_str_list(p), - metavar='P_1,P_2,...,P_N', default=1.0, - help='scheduled sampling probabilities of sampling the truth ' - 'labels for N epochs starting from --start-schedule-sampling-epoch; ' - 'all later epochs using P_N') - parser.add_argument('--start-scheduled-sampling-epoch', type=int, - metavar='N', default=1, - help='start scheduled sampling from the specified epoch') + self.scheduled_sampling_probs = scheduled_sampling_probs + self.start_scheduled_sampling_epoch = start_scheduled_sampling_epoch def step(self, epoch: int) -> float: if ( - (len(self.args.scheduled_sampling_probs) > 1 or - self.args.scheduled_sampling_probs[0] < 1.0) and - epoch >= self.args.start_scheduled_sampling_epoch + (len(self.scheduled_sampling_probs) > 1 or + self.scheduled_sampling_probs[0] < 1.0) and + epoch >= self.start_scheduled_sampling_epoch ): - ss_prob = self.args.scheduled_sampling_probs[ - min(epoch - self.args.start_scheduled_sampling_epoch, - len(self.args.scheduled_sampling_probs) - 1) + prob = self.scheduled_sampling_probs[ + min(epoch - self.start_scheduled_sampling_epoch, + len(self.scheduled_sampling_probs) - 1) ] - return ss_prob + return prob else: return 1.0 diff --git a/espresso/tools/wer.py b/espresso/tools/wer.py index 3d1f826dd..f3c8c25af 100644 --- a/espresso/tools/wer.py +++ b/espresso/tools/wer.py @@ -12,8 +12,8 @@ class Scorer(object): - def __init__(self, dict, wer_output_filter=None): - self.dict = dict + def __init__(self, dictionary, wer_output_filter=None): + self.dictionary = dictionary self.ordered_utt_list = None self.word_filters = [] self.parse_wer_output_filter(wer_output_filter) @@ -57,7 +57,7 @@ def add_prediction(self, utt_id, pred, bpe_symbol=None): 'Duplicated utterance id detected: {}'.format(utt_id) self.char_results[utt_id] = pred + '\n' - pred_words = self.dict.tokens_to_sentence(pred, bpe_symbol=bpe_symbol) + pred_words = self.dictionary.tokens_to_sentence(pred, bpe_symbol=bpe_symbol) assert utt_id not in self.results, \ 'Duplicated utterance id detected: {}'.format(utt_id) self.results[utt_id] = pred_words + '\n' @@ -71,7 +71,7 @@ def add_evaluation(self, utt_id, ref, pred, bpe_symbol=None): raise TypeError('pred must be a string(got {})'.format(type(pred))) # filter out any non_lang_syms from ref and pred - non_lang_syms = getattr(self.dict, 'non_lang_syms', None) + non_lang_syms = getattr(self.dictionary, 'non_lang_syms', None) assert non_lang_syms is None or isinstance(non_lang_syms, list) if non_lang_syms is not None and len(non_lang_syms) > 0: ref_list, pred_list = ref.strip().split(), pred.strip().split() @@ -85,8 +85,8 @@ def add_evaluation(self, utt_id, ref, pred, bpe_symbol=None): self.char_counter += counter # word level counts - ref_words = self.dict.tokens_to_sentence(ref, use_unk_sym=False, bpe_symbol=bpe_symbol) - pred_words = self.dict.tokens_to_sentence(pred, bpe_symbol=bpe_symbol) + ref_words = self.dictionary.tokens_to_sentence(ref, use_unk_sym=False, bpe_symbol=bpe_symbol) + pred_words = self.dictionary.tokens_to_sentence(pred, bpe_symbol=bpe_symbol) # filter words according to self.word_filters (support re.sub only) for pattern, repl in self.word_filters: diff --git a/examples/asr_swbd/run.sh b/examples/asr_swbd/run.sh index 8ded8cdd9..4565d465e 100755 --- a/examples/asr_swbd/run.sh +++ b/examples/asr_swbd/run.sh @@ -252,7 +252,7 @@ if [ $stage -le 6 
]; then --valid-subset $valid_subset --max-sentences-valid 64 --ddp-backend no_c10d \ --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ --max-epoch 35 --optimizer adam --lr 0.001 --weight-decay 0.0 --clip-norm 2.0 \ - --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --start-reduce-lr-epoch 10 \ + --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --start-reduce-lr-epoch 14 \ --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates 1500 \ --keep-interval-updates 3 --keep-last-epochs 5 --validate-interval 1 --best-checkpoint-metric wer \ --arch speech_conv_lstm_swbd --criterion label_smoothed_cross_entropy_with_wer \ diff --git a/tests/espresso/test_speech_dataset.py b/tests/espresso/test_speech_dataset.py index 64be72531..79818cc8c 100644 --- a/tests/espresso/test_speech_dataset.py +++ b/tests/espresso/test_speech_dataset.py @@ -90,7 +90,7 @@ def setUp(self): self.num_transripts = 100 self.batch_size = 8 self.cache_size = 16 - self.dict = self.make_dictionary() + self.dictionary = self.make_dictionary() self.expected_feats = self.generate_feats( self.test_dir, num=self.num_audios, seed=0, ) @@ -115,12 +115,12 @@ def _speech_dataset_helper( ) tgt_dataset = AsrTextDataset( path=os.path.join(self.test_dir, 'text_tokens'), - dictionary=self.dict, + dictionary=self.dictionary, ) dataset = SpeechDataset( src_dataset, src_dataset.sizes, - tgt_dataset, tgt_dataset.sizes, self.dict, + tgt_dataset, tgt_dataset.sizes, self.dictionary, left_pad_source=False, left_pad_target=False, max_source_positions=1000, @@ -151,7 +151,7 @@ def _speech_dataset_helper( self.assertEqual(bsz, len(batch_sampler[i])) src_frames = batch["net_input"]["src_tokens"] src_lengths = batch["net_input"]["src_lengths"] - tgt_tokens = self.dict.string(batch["target"]).split('\n') + tgt_tokens = self.dictionary.string(batch["target"]).split('\n') tgt_tokens = [line.split(' ') for line in tgt_tokens] self.assertEqual(bsz, src_frames.size(0)) self.assertEqual(bsz, src_lengths.numel()) diff --git a/tests/espresso/test_speech_utils.py b/tests/espresso/test_speech_utils.py index e8213d997..bbe991e48 100644 --- a/tests/espresso/test_speech_utils.py +++ b/tests/espresso/test_speech_utils.py @@ -60,7 +60,7 @@ def setUp(self): self.oovs = list(string.ascii_uppercase) self.non_lang_syms = ['', '', ''] self.num_sentences = 100 - self.dict = self.make_dictionary( + self.dictionary = self.make_dictionary( self.vocab, non_lang_syms=self.non_lang_syms, ) @@ -73,31 +73,31 @@ def test_speech_tokenizer(self): print('test sentence {}:'.format(i)) print(sent) tokens = utils.tokenize( - sent, space=self.dict.space_word, + sent, space=self.dictionary.space_word, non_lang_syms=self.non_lang_syms, ) # test :func:`~speech_tools.utils.tokenize` with # :func:`~AsrDictionary.encode_line` - tensor = self.dict.encode_line( + tensor = self.dictionary.encode_line( tokens, add_if_not_exist=False, append_eos=True, ) - reconstructed_tokens = self.dict.string(tensor) + reconstructed_tokens = self.dictionary.string(tensor) expected_tokens = ' '.join( - [token if self.dict.index(token) != self.dict.unk() else - self.dict.unk_word for token in tokens.split(' ')] + [token if self.dictionary.index(token) != self.dictionary.unk() else + self.dictionary.unk_word for token in tokens.split(' ')] ) self.assertEqual(reconstructed_tokens, expected_tokens) # test :func:`~speech_tools.utils.tokenize` with # :func:`~AsrDictionary.tokens_to_sentence` - reconstructed_sent = 
self.dict.tokens_to_sentence(tokens) + reconstructed_sent = self.dictionary.tokens_to_sentence(tokens) expected_sent = [] words = sent.split(' ') for w in words: if w not in self.non_lang_syms: new_word = ''.join( - [self.dict.unk_word if c in self.oovs else c for c in w] + [self.dictionary.unk_word if c in self.oovs else c for c in w] ) expected_sent.append(new_word) else: From 13e2d6289bfc797cbf3ff19daf9dc01a9fb5871d Mon Sep 17 00:00:00 2001 From: freewym Date: Wed, 25 Dec 2019 19:20:30 -0500 Subject: [PATCH 055/119] code adaptation/changes according to the commits from Dec 21 to Dec 24, 2019 --- .../label_smoothed_cross_entropy_with_wer.py | 7 +++-- espresso/speech_train.py | 29 +++++++++++++++++-- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/espresso/criterions/label_smoothed_cross_entropy_with_wer.py b/espresso/criterions/label_smoothed_cross_entropy_with_wer.py index ed87bfdf0..4cda92afd 100644 --- a/espresso/criterions/label_smoothed_cross_entropy_with_wer.py +++ b/espresso/criterions/label_smoothed_cross_entropy_with_wer.py @@ -34,9 +34,10 @@ def label_smoothed_nll_loss( else: raise ValueError('Unsupported smoothing type: {}'.format(smoothing_type)) if ignore_index is not None: - non_pad_mask = target.ne(ignore_index) - nll_loss = nll_loss[non_pad_mask] - smooth_loss = smooth_loss[non_pad_mask] + pad_mask = target.eq(ignore_index) + if pad_mask.any(): + nll_loss.masked_fill_(pad_mask, 0.) + smooth_loss.masked_fill_(pad_mask, 0.) else: nll_loss = nll_loss.squeeze(-1) smooth_loss = smooth_loss.squeeze(-1) diff --git a/espresso/speech_train.py b/espresso/speech_train.py index fa84a59b3..08f0f3b0a 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -82,8 +82,11 @@ def main(args, init_distributed=False): valid_subsets = args.valid_subset.split(',') while ( lr > args.min_lr - and (epoch_itr.epoch < max_epoch or (epoch_itr.epoch == max_epoch - and epoch_itr._next_epoch_itr is not None)) + and ( + epoch_itr.epoch < max_epoch + # allow resuming training from the final checkpoint + or epoch_itr._next_epoch_itr is not None + ) and trainer.get_num_updates() < max_update ): # train for one epoch @@ -101,6 +104,11 @@ def main(args, init_distributed=False): if epoch_itr.epoch % args.save_interval == 0: checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) + # early stop + if should_stop_early(args, valid_losses[0]): + print('| Early stop since valid performance hasn\'t improved for last {} runs'.format(args.patience)) + break + reload_dataset = len(args.train_feat_files) > 1 # sharded data: get train iterator for next epoch epoch_itr = trainer.get_train_iterator(epoch_itr.epoch, load_dataset=reload_dataset) @@ -108,6 +116,23 @@ def main(args, init_distributed=False): print('| done training in {:.1f} seconds'.format(train_meter.sum)) +def should_stop_early(args, valid_loss): + if args.patience <= 0: + return False + + def is_better(a, b): + return a > b if args.maximize_best_checkpoint_metric else a < b + + prev_best = getattr(should_stop_early, 'best', None) + if prev_best is None or is_better(valid_loss, prev_best): + should_stop_early.best = valid_loss + should_stop_early.num_runs = 0 + return False + else: + should_stop_early.num_runs += 1 + return should_stop_early.num_runs > args.patience + + def train(args, trainer, task, epoch_itr): """Train the model for one epoch.""" # Update parameters every N batches From 390d157a4b08ca5efd66439588ace6b388fddfb2 Mon Sep 17 00:00:00 2001 From: freewym Date: Wed, 25 Dec 2019 20:31:52 -0500 
Subject: [PATCH 056/119] move the code of computing prob mask of temporal label smoothing into a separate function --- .../label_smoothed_cross_entropy_with_wer.py | 51 +++++++++++-------- 1 file changed, 30 insertions(+), 21 deletions(-) diff --git a/espresso/criterions/label_smoothed_cross_entropy_with_wer.py b/espresso/criterions/label_smoothed_cross_entropy_with_wer.py index 4cda92afd..df5b3b6ff 100644 --- a/espresso/criterions/label_smoothed_cross_entropy_with_wer.py +++ b/espresso/criterions/label_smoothed_cross_entropy_with_wer.py @@ -16,6 +16,33 @@ from espresso.tools import wer +def temporal_label_smoothing_prob_mask( + lprobs: torch.Tensor, # R[Batch, SeqLength, Vocab] + target: torch.Tensor, # Z[Batch, SeqLength] + padding_index: int = 0, +): + # see https://arxiv.org/pdf/1612.02695.pdf + # prob_mask.dtype=int for deterministic behavior of Tensor.scatter_add_() + prob_mask = torch.zeros_like(lprobs, dtype=torch.int) # bsz x tgtlen x vocab_size + idx_tensor = target.new_full(target.size(), padding_index).unsqueeze(-1) # bsz x tgtlen x 1 + # hard-code the remaining probabilty mass distributed symmetrically + # over neighbors at distance ±1 and ±2 with a 5 : 2 ratio + idx_tensor[:, 2:, 0] = target[:, :-2] # two neighbors to the left + prob_mask.scatter_add_(-1, idx_tensor, prob_mask.new([2]).expand_as(idx_tensor)) + idx_tensor.fill_(padding_index)[:, 1:, 0] = target[:, :-1] + prob_mask.scatter_add_(-1, idx_tensor, prob_mask.new([5]).expand_as(idx_tensor)) + idx_tensor.fill_(padding_index)[:, :-2, 0] = target[:, 2:] # two neighbors to the right + prob_mask.scatter_add_(-1, idx_tensor, prob_mask.new([2]).expand_as(idx_tensor)) + idx_tensor.fill_(padding_index)[:, :-1, 0] = target[:, 1:] + prob_mask.scatter_add_(-1, idx_tensor, prob_mask.new([5]).expand_as(idx_tensor)) + prob_mask[:, :, padding_index] = 0 # clear cumulative count on + prob_mask = prob_mask.float() # convert to float + sum_prob = prob_mask.sum(-1, keepdim=True) + sum_prob[sum_prob.squeeze(-1).eq(0.)] = 1. 
# to deal with the "division by 0" problem + prob_mask = prob_mask.div_(sum_prob).view(-1, prob_mask.size(-1)) + return prob_mask + + def label_smoothed_nll_loss( lprobs, target, epsilon, ignore_index=None, reduce=True, smoothing_type='uniform', prob_mask=None, unigram_tensor=None, @@ -192,27 +219,9 @@ def forward(self, model, sample, reduce=True): print('| sample REF: ' + ref_one) print('| sample PRD: ' + pred_one) # word error stats code ends - prob_mask = None - if self.args.smoothing_type == 'temporal': - # see https://arxiv.org/pdf/1612.02695.pdf - # prob_mask.dtype=int for deterministic behavior of Tensor.scatter_add_() - prob_mask = torch.zeros_like(lprobs, dtype=torch.int) # bsz x tgtlen x vocab_size - idx_tensor = target.new_full(target.size(), self.padding_idx).unsqueeze(-1) # bsz x tgtlen x 1 - # hard-code the remaining probabilty mass distributed symmetrically - # over neighbors at distance ±1 and ±2 with a 5 : 2 ratio - idx_tensor[:, 2:, 0] = target[:, :-2] # two neighbors to the left - prob_mask.scatter_add_(-1, idx_tensor, prob_mask.new([2]).expand_as(idx_tensor)) - idx_tensor.fill_(self.padding_idx)[:, 1:, 0] = target[:, :-1] - prob_mask.scatter_add_(-1, idx_tensor, prob_mask.new([5]).expand_as(idx_tensor)) - idx_tensor.fill_(self.padding_idx)[:, :-2, 0] = target[:, 2:] # two neighbors to the right - prob_mask.scatter_add_(-1, idx_tensor, prob_mask.new([2]).expand_as(idx_tensor)) - idx_tensor.fill_(self.padding_idx)[:, :-1, 0] = target[:, 1:] - prob_mask.scatter_add_(-1, idx_tensor, prob_mask.new([5]).expand_as(idx_tensor)) - prob_mask[:, :, self.padding_idx] = 0 # clear cumulative count on - prob_mask = prob_mask.float() # convert to float - sum_prob = prob_mask.sum(-1, keepdim=True) - sum_prob[sum_prob.squeeze(-1).eq(0.)] = 1. # to deal with the "division by 0" problem - prob_mask = prob_mask.div_(sum_prob).view(-1, prob_mask.size(-1)) + prob_mask = temporal_label_smoothing_prob_mask( + lprobs, target, padding_index=self.padding_idx, + ) if self.args.smoothing_type == 'temporal' else None lprobs = lprobs.view(-1, lprobs.size(-1)) target = target.view(-1, 1) From 3c5b3c97a28aa04a613279d7b7f0192b817a909c Mon Sep 17 00:00:00 2001 From: Yiming Wang Date: Thu, 26 Dec 2019 17:19:22 -0500 Subject: [PATCH 057/119] isolate greedy search code from criterions (#19) * isolate greedy search code from criterions --- espresso/criterions/cross_entropy_with_wer.py | 116 ++++-------------- .../label_smoothed_cross_entropy_with_wer.py | 115 ++++------------- espresso/tools/simple_greedy_decoder.py | 116 ++++++++++++++++++ 3 files changed, 168 insertions(+), 179 deletions(-) create mode 100644 espresso/tools/simple_greedy_decoder.py diff --git a/espresso/criterions/cross_entropy_with_wer.py b/espresso/criterions/cross_entropy_with_wer.py index 78889e3d2..85cb8278d 100644 --- a/espresso/criterions/cross_entropy_with_wer.py +++ b/espresso/criterions/cross_entropy_with_wer.py @@ -4,17 +4,16 @@ # LICENSE file in the root directory of this source tree. 
import numpy as np -import torch import torch.nn.functional as F from fairseq import utils from fairseq.data import data_utils -from fairseq.models import FairseqIncrementalDecoder from fairseq.criterions import register_criterion from fairseq.criterions.cross_entropy import CrossEntropyCriterion from espresso.tools import wer +from espresso.tools.simple_greedy_decoder import SimpleGreedyDecoder @register_criterion('cross_entropy_with_wer') @@ -25,6 +24,7 @@ def __init__(self, args, task): dictionary = task.target_dictionary self.scorer = wer.Scorer(dictionary, wer_output_filter=task.args.wer_output_filter) + self.decoder_for_validation = SimpleGreedyDecoder(dictionary, for_validation=True) self.train_tgt_dataset = None self.valid_tgt_dataset = None self.num_updates = -1 @@ -55,81 +55,11 @@ def forward(self, model, sample, reduce=True): net_output = model(**sample['net_input'], epoch=self.epoch) lprobs = model.get_normalized_probs(net_output, log_probs=True) target = model.get_targets(sample, net_output) - else: - assert isinstance(model.decoder, FairseqIncrementalDecoder) - incremental_states = {} - encoder_input = { - k: v for k, v in sample['net_input'].items() - if k != 'prev_output_tokens' - } - encoder_out = model.encoder(**encoder_input) - target = sample['target'] - # make the maximum decoding length equal to at least the length of - # target, and the length of encoder_out if possible - maxlen = max(encoder_out['encoder_out'][0].size(1), target.size(1)) - tokens = target.new_full([target.size(0), maxlen + 2], self.padding_idx) - tokens[:, 0] = dictionary.eos() - lprobs = [] - attn = [] if getattr(model.decoder, 'need_attn', False) else None - dummy_log_probs = encoder_out['encoder_out'][0].new_full( - [target.size(0), len(dictionary)], -np.log(len(dictionary))) - for step in range(maxlen + 1): # one extra step for EOS marker - is_eos = tokens[:, step].eq(dictionary.eos()) - # if all predictions are finished (i.e., ended with eos), - # pad lprobs to target length with dummy log probs, - # truncate tokens up to this step and break - if step > 0 and is_eos.sum() == is_eos.size(0): - for _ in range(step, target.size(1)): - lprobs.append(dummy_log_probs) - tokens = tokens[:, :step + 1] - break - log_probs, attn_scores = self._decode( - tokens[:, :step + 1], model, encoder_out, incremental_states, - ) - tokens[:, step + 1] = log_probs.argmax(-1) - if step > 0: # deal with finished predictions - # make log_probs uniform if the previous output token is EOS - # and add consecutive EOS to the end of prediction - log_probs[is_eos, :] = -np.log(log_probs.size(1)) - tokens[is_eos, step + 1] = dictionary.eos() - if step < target.size(1): - lprobs.append(log_probs) - if getattr(model.decoder, 'need_attn', False): - attn.append(attn_scores) - # bsz x min(tgtlen, maxlen + 1) x vocab_size - lprobs = torch.stack(lprobs, dim=1) - if getattr(model.decoder, 'need_attn', False): - # bsz x (maxlen + 1) x (length of encoder_out) - attn = torch.stack(attn, dim=1) - # word error stats code starts - if ( - not model.training or - ( + if ( self.num_updates // self.args.print_interval > (self.num_updates - 1) // self.args.print_interval - ) - ): - pred = lprobs.argmax(-1).cpu() if model.training else \ - tokens[:, 1:].data.cpu() # bsz x len - - if not model.training: # validation step, compute WER stats with scorer - assert pred.size(0) == target.size(0) - self.scorer.reset() - for i in range(target.size(0)): - utt_id = sample['utt_id'][i] - id = sample['id'].data[i].item() - # ref_tokens = 
dictionary.string(target.data[i]) - # if it is a dummy batch (e.g., a "padding" batch in a sharded - # dataset), id might exceeds the dataset size; in that case we - # just skip it - if id < len(self.valid_tgt_dataset): - ref_tokens = self.valid_tgt_dataset.get_original_tokens(id) - pred_tokens = dictionary.string(pred.data[i]) - self.scorer.add_evaluation( - utt_id, ref_tokens, pred_tokens, - bpe_symbol=self.args.remove_bpe, - ) - else: # print a randomly sampled result every print_interval updates + ): # print a randomly sampled result every print_interval updates + pred = lprobs.argmax(-1).cpu() # bsz x len assert pred.size() == target.size() with data_utils.numpy_seed(self.num_updates): i = np.random.randint(0, len(sample['id'])) @@ -145,7 +75,27 @@ def forward(self, model, sample, reduce=True): ) print('| sample REF: ' + ref_one) print('| sample PRD: ' + pred_one) - # word error stats code ends + else: + tokens, lprobs, _ = self.decoder_for_validation.decode([model], sample) + pred = tokens[:, 1:].data.cpu() # bsz x len + target = sample['target'] + # compute word error stats + assert pred.size(0) == target.size(0) + self.scorer.reset() + for i in range(target.size(0)): + utt_id = sample['utt_id'][i] + id = sample['id'].data[i].item() + # ref_tokens = dictionary.string(target.data[i]) + # if it is a dummy batch (e.g., a "padding" batch in a sharded + # dataset), id might exceeds the dataset size; in that case we + # just skip it + if id < len(self.valid_tgt_dataset): + ref_tokens = self.valid_tgt_dataset.get_original_tokens(id) + pred_tokens = dictionary.string(pred.data[i]) + self.scorer.add_evaluation( + utt_id, ref_tokens, pred_tokens, bpe_symbol=self.args.remove_bpe, + ) + lprobs = lprobs.view(-1, lprobs.size(-1)) loss = F.nll_loss( lprobs, @@ -184,20 +134,6 @@ def aggregate_logging_outputs(logging_outputs): agg_output['char_count'] = char_count return agg_output - def _decode(self, tokens, model, encoder_out, incremental_states): - decoder_out = list(model.forward_decoder( - tokens, encoder_out=encoder_out, incremental_state=incremental_states, - )) - decoder_out[0] = decoder_out[0][:, -1:, :] - attn = decoder_out[1] - if type(attn) is dict: - attn = attn.get('attn', None) - if attn is not None: - attn = attn[:, -1, :] - probs = model.get_normalized_probs(decoder_out, log_probs=True) - probs = probs[:, -1, :] - return probs, attn - def set_train_tgt_dataset(self, dataset): self.train_tgt_dataset = dataset diff --git a/espresso/criterions/label_smoothed_cross_entropy_with_wer.py b/espresso/criterions/label_smoothed_cross_entropy_with_wer.py index df5b3b6ff..7183e9a5b 100644 --- a/espresso/criterions/label_smoothed_cross_entropy_with_wer.py +++ b/espresso/criterions/label_smoothed_cross_entropy_with_wer.py @@ -8,12 +8,12 @@ from fairseq import utils from fairseq.data import data_utils -from fairseq.models import FairseqIncrementalDecoder from fairseq.criterions import register_criterion from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion from espresso.tools import wer +from espresso.tools.simple_greedy_decoder import SimpleGreedyDecoder def temporal_label_smoothing_prob_mask( @@ -84,6 +84,7 @@ def __init__(self, args, task): dictionary = task.target_dictionary self.scorer = wer.Scorer(dictionary, wer_output_filter=task.args.wer_output_filter) + self.decoder_for_validation = SimpleGreedyDecoder(dictionary, for_validation=True) self.train_tgt_dataset = None self.valid_tgt_dataset = None self.num_updates = -1 @@ -128,81 +129,11 @@ def 
forward(self, model, sample, reduce=True): net_output = model(**sample['net_input'], epoch=self.epoch) lprobs = model.get_normalized_probs(net_output, log_probs=True) target = model.get_targets(sample, net_output) - else: - assert isinstance(model.decoder, FairseqIncrementalDecoder) - incremental_states = {} - encoder_input = { - k: v for k, v in sample['net_input'].items() - if k != 'prev_output_tokens' - } - encoder_out = model.encoder(**encoder_input) - target = sample['target'] - # make the maximum decoding length equal to at least the length of - # target, and the length of encoder_out if possible - maxlen = max(encoder_out['encoder_out'][0].size(1), target.size(1)) - tokens = target.new_full([target.size(0), maxlen + 2], self.padding_idx) - tokens[:, 0] = dictionary.eos() - lprobs = [] - attn = [] if getattr(model.decoder, 'need_attn', False) else None - dummy_log_probs = encoder_out['encoder_out'][0].new_full( - [target.size(0), len(dictionary)], -np.log(len(dictionary))) - for step in range(maxlen + 1): # one extra step for EOS marker - is_eos = tokens[:, step].eq(dictionary.eos()) - # if all predictions are finished (i.e., ended with eos), - # pad lprobs to target length with dummy log probs, - # truncate tokens up to this step and break - if step > 0 and is_eos.sum() == is_eos.size(0): - for _ in range(step, target.size(1)): - lprobs.append(dummy_log_probs) - tokens = tokens[:, :step + 1] - break - log_probs, attn_scores = self._decode( - tokens[:, :step + 1], model, encoder_out, incremental_states, - ) - tokens[:, step + 1] = log_probs.argmax(-1) - if step > 0: # deal with finished predictions - # make log_probs uniform if the previous output token is EOS - # and add consecutive EOS to the end of prediction - log_probs[is_eos, :] = -np.log(log_probs.size(1)) - tokens[is_eos, step + 1] = dictionary.eos() - if step < target.size(1): - lprobs.append(log_probs) - if getattr(model.decoder, 'need_attn', False): - attn.append(attn_scores) - # bsz x min(tgtlen, maxlen + 1) x vocab_size - lprobs = torch.stack(lprobs, dim=1) - if getattr(model.decoder, 'need_attn', False): - # bsz x (maxlen + 1) x (length of encoder_out) - attn = torch.stack(attn, dim=1) - # word error stats code starts - if ( - not model.training or - ( + if ( self.num_updates // self.args.print_interval > (self.num_updates - 1) // self.args.print_interval - ) - ): - pred = lprobs.argmax(-1).cpu() if model.training else \ - tokens[:, 1:].data.cpu() # bsz x len - - if not model.training: # validation step, compute WER stats with scorer - assert pred.size(0) == target.size(0) - self.scorer.reset() - for i in range(target.size(0)): - utt_id = sample['utt_id'][i] - id = sample['id'].data[i].item() - # ref_tokens = dictionary.string(target.data[i]) - # if it is a dummy batch (e.g., a "padding" batch in a sharded - # dataset), id might exceeds the dataset size; in that case we - # just skip it - if id < len(self.valid_tgt_dataset): - ref_tokens = self.valid_tgt_dataset.get_original_tokens(id) - pred_tokens = dictionary.string(pred.data[i]) - self.scorer.add_evaluation( - utt_id, ref_tokens, pred_tokens, - bpe_symbol=self.args.remove_bpe, - ) - else: # print a randomly sampled result every print_interval updates + ): # print a randomly sampled result every print_interval updates + pred = lprobs.argmax(-1).cpu() # bsz x len assert pred.size() == target.size() with data_utils.numpy_seed(self.num_updates): i = np.random.randint(0, len(sample['id'])) @@ -218,7 +149,27 @@ def forward(self, model, sample, reduce=True): ) print('| 
sample REF: ' + ref_one) print('| sample PRD: ' + pred_one) - # word error stats code ends + else: + tokens, lprobs, _ = self.decoder_for_validation.decode([model], sample) + pred = tokens[:, 1:].data.cpu() # bsz x len + target = sample['target'] + # compute word error stats + assert pred.size(0) == target.size(0) + self.scorer.reset() + for i in range(target.size(0)): + utt_id = sample['utt_id'][i] + id = sample['id'].data[i].item() + # ref_tokens = dictionary.string(target.data[i]) + # if it is a dummy batch (e.g., a "padding" batch in a sharded + # dataset), id might exceeds the dataset size; in that case we + # just skip it + if id < len(self.valid_tgt_dataset): + ref_tokens = self.valid_tgt_dataset.get_original_tokens(id) + pred_tokens = dictionary.string(pred.data[i]) + self.scorer.add_evaluation( + utt_id, ref_tokens, pred_tokens, bpe_symbol=self.args.remove_bpe, + ) + prob_mask = temporal_label_smoothing_prob_mask( lprobs, target, padding_index=self.padding_idx, ) if self.args.smoothing_type == 'temporal' else None @@ -261,20 +212,6 @@ def aggregate_logging_outputs(logging_outputs): agg_output['char_count'] = char_count return agg_output - def _decode(self, tokens, model, encoder_out, incremental_states): - decoder_out = list(model.forward_decoder( - tokens, encoder_out=encoder_out, incremental_state=incremental_states, - )) - decoder_out[0] = decoder_out[0][:, -1:, :] - attn = decoder_out[1] - if type(attn) is dict: - attn = attn.get('attn', None) - if attn is not None: - attn = attn[:, -1, :] - probs = model.get_normalized_probs(decoder_out, log_probs=True) - probs = probs[:, -1, :] - return probs, attn - def set_train_tgt_dataset(self, dataset): self.train_tgt_dataset = dataset diff --git a/espresso/tools/simple_greedy_decoder.py b/espresso/tools/simple_greedy_decoder.py new file mode 100644 index 000000000..acb929ce4 --- /dev/null +++ b/espresso/tools/simple_greedy_decoder.py @@ -0,0 +1,116 @@ +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np + +import torch + + +class SimpleGreedyDecoder(object): + def __init__( + self, dictionary, max_len_a=0, max_len_b=200, temperature=1., for_validation=True, + ): + """Decode given speech audios with the simple greedy search. + + Args: + dictionary (~fairseq.data.Dictionary): dictionary + max_len_a/b (int, optional): generate sequences of maximum length + ax + b, where x is the source length + temperature (float, optional): temperature, where values + >1.0 produce more uniform samples and values <1.0 produce + sharper samples (default: 1.0) + for_validation (bool, optional): indicate whether the decoder is + used for validation. It affects how max_len is determined, and + whether a tensor of lprobs is returned. If true, target should be + not None + """ + self.pad = dictionary.pad() + self.unk = dictionary.unk() + self.eos = dictionary.eos() + self.vocab_size = len(dictionary) + self.max_len_a = max_len_a + self.max_len_b = max_len_b + self.temperature = temperature + assert temperature > 0, '--temperature must be greater than 0' + self.for_validation = for_validation + + @torch.no_grad() + def decode(self, models, sample, **kwargs): + """Generate a batch of translations. 
+ + Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models + sample (dict): batch + bos_token (int, optional): beginning of sentence token + (default: self.eos) + """ + from fairseq.sequence_generator import EnsembleModel + model = EnsembleModel(models) + return self._decode(model, sample, **kwargs) + + @torch.no_grad() + def _decode(self, model, sample, bos_token=None, **kwargs): + model.eval() + + # model.forward normally channels prev_output_tokens into the decoder + # separately, but SimpleGreedyDecoder directly calls model.encoder + encoder_input = { + k: v for k, v in sample["net_input"].items() + if k != "prev_output_tokens" + } + src_tokens = encoder_input["src_tokens"] + input_size = src_tokens.size() + # batch dimension goes first followed by source lengths + bsz = input_size[0] + src_len = input_size[1] + + encoder_outs = model.forward_encoder(encoder_input) + target = sample["target"] + # target can only be None if not for validation + assert target is not None or not self.for_validation + max_encoder_output_length = encoder_outs[0]["encoder_out"][0].size(0) + # for validation, make the maximum decoding length equal to at least the + # length of target, and the length of encoder_out if possible; otherwise + # max_len is obtained from max_len_a/b + max_len = max(max_encoder_output_length, target.size(1)) \ + if self.for_validation else \ + min( + int(self.max_len_a * src_len + self.max_len_b), + # exclude the EOS marker + model.max_decoder_positions() - 1, + ) + + tokens = src_tokens.new(bsz, max_len + 2).long().fill_(self.pad) + tokens[:, 0] = self.eos if bos_token is None else bos_token + # lprobs is only used when target is not None (i.e., for validation) + lprobs = encoder_outs[0]["encoder_out"][0].new_full( + (bsz, target.size(1), self.vocab_size), -np.log(self.vocab_size), + ) if self.for_validation else None + attn = None + for step in range(max_len + 1): # one extra step for EOS marker + is_eos = tokens[:, step].eq(self.eos) + if step > 0 and is_eos.sum() == is_eos.size(0): + # all predictions are finished (i.e., ended with eos) + tokens = tokens[:, :step + 1] + if attn is not None: + attn = attn[:, :, :step + 1] + break + log_probs, avg_attn_scores = model.forward_decoder( + tokens[:, :step + 1], encoder_outs, temperature=self.temperature, + ) + tokens[:, step + 1] = log_probs.argmax(-1) + if step > 0: # deal with finished predictions + # make log_probs uniform if the previous output token is EOS + # and add consecutive EOS to the end of prediction + log_probs[is_eos, :] = -np.log(log_probs.size(1)) + tokens[is_eos, step + 1] = self.eos + if self.for_validation and step < target.size(1): + lprobs[:, step, :] = log_probs + if avg_attn_scores is not None: + if attn is None: + attn = avg_attn_scores.new(bsz, max_encoder_output_length, max_len + 2) + attn[:, :, step + 1].copy_(avg_attn_scores) + + return tokens, lprobs, attn From 54870cb485d984af7d3885c39072de3bcb184893 Mon Sep 17 00:00:00 2001 From: Yiming Wang Date: Fri, 27 Dec 2019 15:43:04 -0500 Subject: [PATCH 058/119] remove the need to pass tgt dataset to criterions by adding a raw text field to collated samples (#20) --- espresso/criterions/cross_entropy_with_wer.py | 33 +++++-------------- .../label_smoothed_cross_entropy_with_wer.py | 33 +++++-------------- espresso/data/scp_text_dataset.py | 29 ++++++++-------- espresso/data/speech_dataset.py | 13 +++++++- espresso/models/speech_lstm.py | 6 ++-- espresso/speech_recognize.py | 6 ++-- espresso/speech_train.py | 8 +---- 
espresso/tasks/speech_recognition.py | 6 ++-- espresso/tools/Makefile | 10 +++--- 9 files changed, 61 insertions(+), 83 deletions(-) diff --git a/espresso/criterions/cross_entropy_with_wer.py b/espresso/criterions/cross_entropy_with_wer.py index 85cb8278d..87b4d536a 100644 --- a/espresso/criterions/cross_entropy_with_wer.py +++ b/espresso/criterions/cross_entropy_with_wer.py @@ -25,8 +25,6 @@ def __init__(self, args, task): dictionary = task.target_dictionary self.scorer = wer.Scorer(dictionary, wer_output_filter=task.args.wer_output_filter) self.decoder_for_validation = SimpleGreedyDecoder(dictionary, for_validation=True) - self.train_tgt_dataset = None - self.valid_tgt_dataset = None self.num_updates = -1 self.epoch = 0 @@ -63,14 +61,13 @@ def forward(self, model, sample, reduce=True): assert pred.size() == target.size() with data_utils.numpy_seed(self.num_updates): i = np.random.randint(0, len(sample['id'])) - id = sample['id'].data[i].item() + ref_tokens = sample['target_raw_text'][i] length = utils.strip_pad(target.data[i], self.padding_idx).size(0) - # ref_one = dictionary.tokens_to_sentence(dictionary.string(target.data[i])) - ref_one = self.train_tgt_dataset.get_original_text( - id, dictionary, bpe_symbol=self.args.remove_bpe, + ref_one = dictionary.tokens_to_sentence( + ref_tokens, use_unk_sym=False, bpe_symbol=self.args.remove_bpe, ) pred_one = dictionary.tokens_to_sentence( - dictionary.string(pred.data[i][:length]), + dictionary.string(pred.data[i][:length]), use_unk_sym=True, bpe_symbol=self.args.remove_bpe, ) print('| sample REF: ' + ref_one) @@ -84,17 +81,11 @@ def forward(self, model, sample, reduce=True): self.scorer.reset() for i in range(target.size(0)): utt_id = sample['utt_id'][i] - id = sample['id'].data[i].item() - # ref_tokens = dictionary.string(target.data[i]) - # if it is a dummy batch (e.g., a "padding" batch in a sharded - # dataset), id might exceeds the dataset size; in that case we - # just skip it - if id < len(self.valid_tgt_dataset): - ref_tokens = self.valid_tgt_dataset.get_original_tokens(id) - pred_tokens = dictionary.string(pred.data[i]) - self.scorer.add_evaluation( - utt_id, ref_tokens, pred_tokens, bpe_symbol=self.args.remove_bpe, - ) + ref_tokens = sample['target_raw_text'][i] + pred_tokens = dictionary.string(pred.data[i]) + self.scorer.add_evaluation( + utt_id, ref_tokens, pred_tokens, bpe_symbol=self.args.remove_bpe, + ) lprobs = lprobs.view(-1, lprobs.size(-1)) loss = F.nll_loss( @@ -134,12 +125,6 @@ def aggregate_logging_outputs(logging_outputs): agg_output['char_count'] = char_count return agg_output - def set_train_tgt_dataset(self, dataset): - self.train_tgt_dataset = dataset - - def set_valid_tgt_dataset(self, dataset): - self.valid_tgt_dataset = dataset - def set_num_updates(self, num_updates): self.num_updates = num_updates diff --git a/espresso/criterions/label_smoothed_cross_entropy_with_wer.py b/espresso/criterions/label_smoothed_cross_entropy_with_wer.py index 7183e9a5b..5d022a6ce 100644 --- a/espresso/criterions/label_smoothed_cross_entropy_with_wer.py +++ b/espresso/criterions/label_smoothed_cross_entropy_with_wer.py @@ -85,8 +85,6 @@ def __init__(self, args, task): dictionary = task.target_dictionary self.scorer = wer.Scorer(dictionary, wer_output_filter=task.args.wer_output_filter) self.decoder_for_validation = SimpleGreedyDecoder(dictionary, for_validation=True) - self.train_tgt_dataset = None - self.valid_tgt_dataset = None self.num_updates = -1 self.epoch = 0 self.unigram_tensor = None @@ -137,14 +135,13 @@ def forward(self, 
model, sample, reduce=True): assert pred.size() == target.size() with data_utils.numpy_seed(self.num_updates): i = np.random.randint(0, len(sample['id'])) - id = sample['id'].data[i].item() + ref_tokens = sample['target_raw_text'][i] length = utils.strip_pad(target.data[i], self.padding_idx).size(0) - # ref_one = dictionary.tokens_to_sentence(dictionary.string(target.data[i])) - ref_one = self.train_tgt_dataset.get_original_text( - id, dictionary, bpe_symbol=self.args.remove_bpe, + ref_one = dictionary.tokens_to_sentence( + ref_tokens, use_unk_sym=False, bpe_symbol=self.args.remove_bpe, ) pred_one = dictionary.tokens_to_sentence( - dictionary.string(pred.data[i][:length]), + dictionary.string(pred.data[i][:length]), use_unk_sym=True, bpe_symbol=self.args.remove_bpe, ) print('| sample REF: ' + ref_one) @@ -158,17 +155,11 @@ def forward(self, model, sample, reduce=True): self.scorer.reset() for i in range(target.size(0)): utt_id = sample['utt_id'][i] - id = sample['id'].data[i].item() - # ref_tokens = dictionary.string(target.data[i]) - # if it is a dummy batch (e.g., a "padding" batch in a sharded - # dataset), id might exceeds the dataset size; in that case we - # just skip it - if id < len(self.valid_tgt_dataset): - ref_tokens = self.valid_tgt_dataset.get_original_tokens(id) - pred_tokens = dictionary.string(pred.data[i]) - self.scorer.add_evaluation( - utt_id, ref_tokens, pred_tokens, bpe_symbol=self.args.remove_bpe, - ) + ref_tokens = sample['target_raw_text'][i] + pred_tokens = dictionary.string(pred.data[i]) + self.scorer.add_evaluation( + utt_id, ref_tokens, pred_tokens, bpe_symbol=self.args.remove_bpe, + ) prob_mask = temporal_label_smoothing_prob_mask( lprobs, target, padding_index=self.padding_idx, @@ -212,12 +203,6 @@ def aggregate_logging_outputs(logging_outputs): agg_output['char_count'] = char_count return agg_output - def set_train_tgt_dataset(self, dataset): - self.train_tgt_dataset = dataset - - def set_valid_tgt_dataset(self, dataset): - self.valid_tgt_dataset = dataset - def set_num_updates(self, num_updates): self.num_updates = num_updates diff --git a/espresso/data/scp_text_dataset.py b/espresso/data/scp_text_dataset.py index c976e78e0..eb068a42c 100644 --- a/espresso/data/scp_text_dataset.py +++ b/espresso/data/scp_text_dataset.py @@ -15,7 +15,12 @@ class ScpDataset(torch.utils.data.Dataset): - """Loader for TorchNet IndexedDataset""" + """ + A dataset for audio features prepared in Kaldi scp format (e.g., feats.scp). + See http://kaldi-asr.org/doc/tutorial_running.html#tutorial_running_feats + for the format descriptions. This class loads a feature matrix from the disk + every time each entry is inquired, thus incurs the most intensive I/O. + """ def __init__(self, path): super().__init__() @@ -74,6 +79,11 @@ def exists(path): class ScpCachedDataset(ScpDataset): + """ + This class loads a batch of feature matrices (specified as *cache_size*) + every time an entry is inquired. The inquire order should be known in advance. + It balances the I/O efficiency and memory usage. + """ def __init__(self, path, ordered_prefetch=False, cache_size=4096): super().__init__(path) @@ -143,7 +153,10 @@ def __getitem__(self, i): class ScpInMemoryDataset(ScpDataset): - """Loader for TorchNet ScpDataset, keeps all the data in memory.""" + """ + This class loads all feature matrices into memory at once. + It has the maximum memory usage and least I/O. 
+ """ def __init__(self, path): super().__init__(path) @@ -223,17 +236,7 @@ def filter_and_reorder(self, indices): def __getitem__(self, i): self.check_index(i) - return self.tensor_list[i] - - def get_original_tokens(self, i): - self.check_index(i) - return self.tokens_list[i] - - def get_original_text(self, i, dictionary, bpe_symbol=None): - self.check_index(i) - return dictionary.tokens_to_sentence( - self.tokens_list[i], use_unk_sym=False, bpe_symbol=bpe_symbol, - ) + return self.tensor_list[i], self.tokens_list[i] def __len__(self): return self.size diff --git a/espresso/data/speech_dataset.py b/espresso/data/speech_dataset.py index 8bd6e82c3..8dd1c6b01 100644 --- a/espresso/data/speech_dataset.py +++ b/espresso/data/speech_dataset.py @@ -60,6 +60,11 @@ def merge(key, left_pad, move_eos_to_beginning=False): else: ntokens = sum(s['source'].size(0) for s in samples) + target_raw_text = None + if samples[0].get('target_raw_text', None) is not None: + target_raw_text = [s['target_raw_text'] for s in samples] + target_raw_text = [target_raw_text[i] for i in sort_order.numpy()] + batch = { 'id': id, 'utt_id': utt_id, @@ -70,6 +75,7 @@ def merge(key, left_pad, move_eos_to_beginning=False): 'src_lengths': src_lengths, }, 'target': target, + 'target_raw_text': target_raw_text, } if prev_output_tokens is not None: batch['net_input']['prev_output_tokens'] = prev_output_tokens @@ -144,13 +150,15 @@ def _match_src_tgt(self): assert self.src.utt_ids == self.tgt.utt_ids def __getitem__(self, index): - tgt_item = self.tgt[index] if self.tgt is not None else None + tgt_item = self.tgt[index][0] if self.tgt is not None else None + raw_text_item = self.tgt[index][1] if self.tgt is not None else None src_item = self.src[index] example = { 'id': index, 'utt_id': self.src.utt_ids[index], 'source': src_item, 'target': tgt_item, + 'target_raw_text': raw_text_item, } return example @@ -167,6 +175,8 @@ def collater(self, samples): dict: a mini-batch with the following keys: - `id` (LongTensor): example IDs in the original input order + - `utt_id` (List[str]): list of utterance ids + - `nsentences` (int): batch size - `ntokens` (int): total number of tokens in the batch - `net_input` (dict): the input to the Model, containing keys: @@ -185,6 +195,7 @@ def collater(self, samples): - `target` (LongTensor): a padded 2D Tensor of tokens in the target sentence of shape `(bsz, tgt_len)`. Padding will appear on the left if *left_pad_target* is ``True``. 
+ - `target_raw_text` (List[str]): list of original text """ return collate( samples, pad_idx=self.dictionary.pad(), eos_idx=self.dictionary.eos(), diff --git a/espresso/models/speech_lstm.py b/espresso/models/speech_lstm.py index 1168f6945..b76cb7c3f 100644 --- a/espresso/models/speech_lstm.py +++ b/espresso/models/speech_lstm.py @@ -228,7 +228,7 @@ def eval_str_nested_list_or_tuple(x, type=int): if args.pretrained_lm_checkpoint: print('| loading pretrained LM from {}'.format(args.pretrained_lm_checkpoint)) pretrained_lm = checkpoint_utils.load_model_ensemble( - args.pretrained_lm_checkpoint, task)[0][0] + args.pretrained_lm_checkpoint, task=task)[0][0] pretrained_lm.make_generation_fast_() # freeze pretrained model for param in pretrained_lm.parameters(): @@ -584,7 +584,7 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, if epoch > 0: sampling_prob = self.scheduled_sampling_rate_scheduler.step(epoch) if sampling_prob < 1.0: # apply scheduled sampling - return self._forward_with_schduled_sampling( + return self._forward_with_scheduled_sampling( prev_output_tokens, sampling_prob, encoder_out=encoder_out, incremental_state={}, # use empty dict to preserve forward state ) @@ -594,7 +594,7 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, ) return self.output_layer(x), attn_scores - def _forward_with_schduled_sampling( + def _forward_with_scheduled_sampling( self, prev_output_tokens, sampling_prob, encoder_out=None, incremental_state=None, ): bsz, seqlen = prev_output_tokens.size() diff --git a/espresso/speech_recognize.py b/espresso/speech_recognize.py index d7cc874e6..cc65f46bb 100755 --- a/espresso/speech_recognize.py +++ b/espresso/speech_recognize.py @@ -96,7 +96,7 @@ def main(args): else (None, model.max_positions()) for model in models] ), ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, - required_batch_size_multiple=8, + required_batch_size_multiple=args.required_batch_size_multiple, num_shards=args.num_shards, shard_id=args.shard_id, num_workers=args.num_workers, @@ -138,13 +138,13 @@ def main(args): output_lengths = models[0].encoder.output_lengths(net_input['src_lengths']) nonpad_idxs = sequence_mask(output_lengths, models[0].encoder.output_lengths(src_tokens.size(1))) - for i, sample_id in enumerate(sample['id'].tolist()): + for i in range(len(sample['id'])): has_target = sample['target'] is not None utt_id = sample['utt_id'][i] # Retrieve the original sentences if has_target: - target_str = task.dataset(args.gen_subset).tgt.get_original_tokens(sample_id) + target_str = sample['target_raw_text'][i] if not args.quiet: target_sent = dictionary.tokens_to_sentence( target_str, use_unk_sym=False, bpe_symbol=args.remove_bpe, diff --git a/espresso/speech_train.py b/espresso/speech_train.py index 08f0f3b0a..35ed339f7 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -70,9 +70,6 @@ def main(args, init_distributed=False): # corresponding train iterator extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer) - if hasattr(trainer.criterion, 'set_train_tgt_dataset'): - trainer.criterion.set_train_tgt_dataset(task.dataset(args.train_subset).tgt) - # Train until the learning rate gets too small max_epoch = args.max_epoch or math.inf max_update = args.max_update or math.inf @@ -270,9 +267,6 @@ def validate(args, trainer, task, epoch_itr, subsets): meter.reset() extra_meters = collections.defaultdict(lambda: AverageMeter()) - if hasattr(trainer.criterion, 
'set_valid_tgt_dataset'): - trainer.criterion.set_valid_tgt_dataset(task.dataset(subset).tgt) - for sample in progress: log_output = trainer.valid_step(sample) @@ -351,7 +345,7 @@ def print_options_meaning_changes(args): def cli_main(): - parser = options.get_training_parser(default_task='speech_recognition') + parser = options.get_training_parser() parser.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None, help='remove BPE tokens before scoring ' '(can be set to sentencepiece). Being used for monitoring ' diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index a23967ebf..40b8a4bab 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -130,8 +130,8 @@ def setup_task(cls, args, **kwargs): args.left_pad_target = options.eval_bool(args.left_pad_target) # load dictionaries - dict_path = os.path.join(os.path.dirname(args.text_files[0]), 'dict.txt') \ - if args.dict is None and args.text_files is not None else args.dict + dict_path = os.path.join(os.path.dirname(args.train_text_files[0]), 'dict.txt') \ + if args.dict is None and args.train_text_files is not None else args.dict assert dict_path is not None, 'Please specify --dict' dictionary = cls.load_dictionary(dict_path, non_lang_syms=args.non_lang_syms) print('| dictionary: {} types'.format(len(dictionary))) @@ -219,7 +219,7 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs): self.dictionary.count[self.dictionary.eos()] = len(tgt_dataset) unk_count = 0 for i in range(len(tgt_dataset)): - unk_count += (tgt_dataset[i] == self.dictionary.unk()).int().sum().item() + unk_count += (tgt_dataset[i][0] == self.dictionary.unk()).int().sum().item() self.dictionary.count[self.dictionary.unk()] = unk_count def build_generator(self, args): diff --git a/espresso/tools/Makefile b/espresso/tools/Makefile index e5fec0150..5cba5e3e9 100644 --- a/espresso/tools/Makefile +++ b/espresso/tools/Makefile @@ -6,13 +6,13 @@ all: kaldi ifneq ($(strip $(KALDI)),) kaldi: - ln -s $(KALDI) kaldi + ln -s $(KALDI) kaldi else kaldi: - test -d kaldi || git clone https://github.com/kaldi-asr/kaldi.git - cd kaldi/tools; $(MAKE) all - cd kaldi/src; ./configure --shared --use-cuda=no; $(MAKE) depend; $(MAKE) all + test -d kaldi || git clone https://github.com/kaldi-asr/kaldi.git + cd kaldi/tools; $(MAKE) all + cd kaldi/src; ./configure --shared --use-cuda=no; $(MAKE) depend; $(MAKE) all endif clean: - rm -rf kaldi + rm -rf kaldi From 3677002e90f737348a0cb90cb511f27e0c5aa6e7 Mon Sep 17 00:00:00 2001 From: freewym Date: Fri, 3 Jan 2020 20:00:53 -0500 Subject: [PATCH 059/119] code adaptation/changes according to the commits from Jan 3 to Jan 6, 2020 --- espresso/data/speech_dataset.py | 6 ++-- espresso/speech_train.py | 11 +++--- espresso/tasks/speech_recognition.py | 50 ++++++++++++++++++++++++---- 3 files changed, 52 insertions(+), 15 deletions(-) diff --git a/espresso/data/speech_dataset.py b/espresso/data/speech_dataset.py index 8dd1c6b01..b517e451b 100644 --- a/espresso/data/speech_dataset.py +++ b/espresso/data/speech_dataset.py @@ -32,13 +32,12 @@ def merge(key, left_pad, move_eos_to_beginning=False): raise ValueError('Invalid key.') id = torch.LongTensor([s['id'] for s in samples]) - utt_id = [s['utt_id'] for s in samples] src_frames = merge('source', left_pad=left_pad_source) # sort by descending source length src_lengths = torch.IntTensor([s['source'].size(0) for s in samples]) src_lengths, sort_order = src_lengths.sort(descending=True) id = id.index_select(0, 
sort_order) - utt_id = [utt_id[i] for i in sort_order.numpy()] + utt_id = [samples[i]['utt_id'] for i in sort_order.numpy()] src_frames = src_frames.index_select(0, sort_order) prev_output_tokens = None @@ -62,8 +61,7 @@ def merge(key, left_pad, move_eos_to_beginning=False): target_raw_text = None if samples[0].get('target_raw_text', None) is not None: - target_raw_text = [s['target_raw_text'] for s in samples] - target_raw_text = [target_raw_text[i] for i in sort_order.numpy()] + target_raw_text = [samples[i]['target_raw_text'] for i in sort_order.numpy()] batch = { 'id': id, diff --git a/espresso/speech_train.py b/espresso/speech_train.py index 35ed339f7..b64acea88 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -132,15 +132,16 @@ def is_better(a, b): def train(args, trainer, task, epoch_itr): """Train the model for one epoch.""" - # Update parameters every N batches - update_freq = args.update_freq[epoch_itr.epoch - 1] \ - if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1] - # Initialize data iterator itr = epoch_itr.next_epoch_itr( fix_batches_to_gpus=args.fix_batches_to_gpus, shuffle=(epoch_itr.epoch >= args.curriculum), ) + update_freq = ( + args.update_freq[epoch_itr.epoch - 1] + if epoch_itr.epoch <= len(args.update_freq) + else args.update_freq[-1] + ) itr = iterators.GroupedIterator(itr, update_freq) progress = progress_bar.build_progress_bar( args, itr, epoch_itr.epoch, no_progress_bar='simple', @@ -375,7 +376,7 @@ def cli_main(): args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port) args.distributed_rank = None # set based on device id if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d': - print('| NOTE: you may get better performance with: --ddp-backend=no_c10d') + print('| NOTE: you may get faster training with: --ddp-backend=no_c10d') torch.multiprocessing.spawn( fn=distributed_main, args=(args, ), diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index 40b8a4bab..6bd9d360b 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -7,7 +7,7 @@ import torch -from fairseq import options +from fairseq import options, search from fairseq.data import ConcatDataset from fairseq.tasks import FairseqTask, register_task @@ -228,6 +228,48 @@ def build_generator(self, args): print('| --score-reference is not applicable to speech recognition,' ' ignoring it.') from fairseq.sequence_generator import SequenceGenerator + + # Choose search strategy. Defaults to Beam Search. 
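(Illustrative aside, not part of the patch: the block added below chooses exactly one search strategy from a set of mutually exclusive generation options, falling back to plain beam search when none is set. A minimal stand-alone sketch of that selection pattern, with a hypothetical helper name but the same option semantics, is:)
def pick_search_mode(sampling=False, diverse_beam_groups=-1,
                     match_source_len=False, diversity_rate=-1):
    # At most one special mode may be active; none active means plain beam search.
    active = [name for name, cond in [
        ('sampling', sampling),
        ('diverse_beam_search', diverse_beam_groups > 0),
        ('length_constrained_beam_search', match_source_len),
        ('diverse_siblings_search', diversity_rate > 0),
    ] if cond]
    if len(active) > 1:
        raise ValueError('Provided Search parameters are mutually exclusive.')
    return active[0] if active else 'beam_search'

assert pick_search_mode() == 'beam_search'
assert pick_search_mode(diverse_beam_groups=2) == 'diverse_beam_search'
(The real code below instantiates the corresponding class from fairseq's search module and hands it to SequenceGenerator through the new search_strategy argument.)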
+ sampling = getattr(args, 'sampling', False) +        sampling_topk = getattr(args, 'sampling_topk', -1) +        sampling_topp = getattr(args, 'sampling_topp', -1.0) +        diverse_beam_groups = getattr(args, 'diverse_beam_groups', -1) +        diverse_beam_strength = getattr(args, 'diverse_beam_strength', 0.5) +        match_source_len = getattr(args, 'match_source_len', False) +        diversity_rate = getattr(args, 'diversity_rate', -1) +        if ( +            sum( +                int(cond) +                for cond in [ +                    sampling, +                    diverse_beam_groups > 0, +                    match_source_len, +                    diversity_rate > 0, +                ] +            ) +            > 1 +        ): +            raise ValueError('Provided Search parameters are mutually exclusive.') +        assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling' +        assert sampling_topp < 0 or sampling, '--sampling-topp requires --sampling' + +        if sampling: +            search_strategy = search.Sampling(self.target_dictionary, sampling_topk, sampling_topp) +        elif diverse_beam_groups > 0: +            search_strategy = search.DiverseBeamSearch( +                self.target_dictionary, diverse_beam_groups, diverse_beam_strength) +        elif match_source_len: +            # this is useful for tagging applications where the output +            # length should match the input length, so we hardcode the +            # length constraints for simplicity +            search_strategy = search.LengthConstrainedBeamSearch( +                self.target_dictionary, min_len_a=1, min_len_b=0, max_len_a=1, max_len_b=0, +            ) +        elif diversity_rate > -1: +            search_strategy = search.DiverseSiblingsSearch(self.target_dictionary, diversity_rate) +        else: +            search_strategy = search.BeamSearch(self.target_dictionary) + return SequenceGenerator( self.target_dictionary, beam_size=getattr(args, 'beam', 5), @@ -237,14 +279,10 @@ def build_generator(self, args): normalize_scores=(not getattr(args, 'unnormalized', False)), len_penalty=getattr(args, 'lenpen', 1), unk_penalty=getattr(args, 'unkpen', 0), - sampling=getattr(args, 'sampling', False), - sampling_topk=getattr(args, 'sampling_topk', -1), - sampling_topp=getattr(args, 'sampling_topp', -1.0), temperature=getattr(args, 'temperature', 1.), - diverse_beam_groups=getattr(args, 'diverse_beam_groups', -1), - diverse_beam_strength=getattr(args, 'diverse_beam_strength', 0.5), match_source_len=getattr(args, 'match_source_len', False), no_repeat_ngram_size=getattr(args, 'no_repeat_ngram_size', 0), + search_strategy=search_strategy, eos_factor=getattr(args, 'eos_factor', None), ) From f1bed6f8b3fd6d32d0b8f38e2c406c6b87de144c Mon Sep 17 00:00:00 2001 From: freewym Date: Sat, 11 Jan 2020 18:22:13 -0500 Subject: [PATCH 060/119] code adaptation/changes according to the commits from Jan 11 to Jan 14, 2020 --- espresso/criterions/cross_entropy_with_wer.py | 19 +-- .../label_smoothed_cross_entropy_with_wer.py | 17 +- espresso/models/external_language_model.py | 2 +- .../tensorized_lookahead_language_model.py | 2 +- espresso/speech_recognize.py | 5 +- espresso/speech_train.py | 151 +++++------------- 6 files changed, 58 insertions(+), 138 deletions(-) diff --git a/espresso/criterions/cross_entropy_with_wer.py b/espresso/criterions/cross_entropy_with_wer.py index 87b4d536a..7c48799cc 100644 --- a/espresso/criterions/cross_entropy_with_wer.py +++ b/espresso/criterions/cross_entropy_with_wer.py @@ -4,13 +4,13 @@
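(Illustrative aside, not part of the patch: the reduce_metrics() hunks that follow sum error and token counts across all logging outputs and divide once, so the reported WER/CER is the corpus-level, count-weighted rate rather than an average of per-batch rates. A tiny self-contained example of why that distinction matters:)
logging_outputs = [
    {'word_error': 3, 'word_count': 10},   # short batch: 30% WER
    {'word_error': 1, 'word_count': 90},   # long batch: ~1.1% WER
]
word_error = sum(log.get('word_error', 0) for log in logging_outputs)
word_count = sum(log.get('word_count', 0) for log in logging_outputs)
print(float(word_error) / word_count * 100)  # 4.0: corpus-level (count-weighted) WER
print(sum(l['word_error'] / l['word_count'] for l in logging_outputs) / 2 * 100)  # ~15.6: unweighted per-batch mean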
import numpy as np -import torch.nn.functional as F -from fairseq import utils -from fairseq.data import data_utils +import torch.nn.functional as F +from fairseq import metrics, utils from fairseq.criterions import register_criterion from fairseq.criterions.cross_entropy import CrossEntropyCriterion +from fairseq.data import data_utils from espresso.tools import wer from espresso.tools.simple_greedy_decoder import SimpleGreedyDecoder @@ -97,7 +97,6 @@ def forward(self, model, sample, reduce=True): sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] logging_output = { 'loss': utils.item(loss.data) if reduce else loss.data, - 'nll_loss': utils.item(loss.data) if reduce else loss.data, 'ntokens': sample['ntokens'], 'nsentences': sample['target'].size(0), 'sample_size': sample_size, @@ -110,20 +109,18 @@ def forward(self, model, sample, reduce=True): return loss, sample_size, logging_output @staticmethod -    def aggregate_logging_outputs(logging_outputs): +    def reduce_metrics(logging_outputs) -> None: """Aggregate logging outputs from data parallel training.""" -        agg_output = CrossEntropyCriterion.aggregate_logging_outputs(logging_outputs) +        CrossEntropyCriterion.reduce_metrics(logging_outputs) +        word_error = sum(log.get('word_error', 0) for log in logging_outputs) word_count = sum(log.get('word_count', 0) for log in logging_outputs) char_error = sum(log.get('char_error', 0) for log in logging_outputs) char_count = sum(log.get('char_count', 0) for log in logging_outputs) if word_count > 0:  # model.training == False -            agg_output['word_error'] = word_error -            agg_output['word_count'] = word_count +            metrics.log_scalar('wer', float(word_error) / word_count * 100, word_count, round=4) if char_count > 0:  # model.training == False -            agg_output['char_error'] = char_error -            agg_output['char_count'] = char_count -        return agg_output +            metrics.log_scalar('cer', float(char_error) / char_count * 100, char_count, round=4) def set_num_updates(self, num_updates): self.num_updates = num_updates diff --git a/espresso/criterions/label_smoothed_cross_entropy_with_wer.py b/espresso/criterions/label_smoothed_cross_entropy_with_wer.py index 5d022a6ce..713591c7c 100644 --- a/espresso/criterions/label_smoothed_cross_entropy_with_wer.py +++ b/espresso/criterions/label_smoothed_cross_entropy_with_wer.py @@ -6,11 +6,10 @@ import numpy as np import torch -from fairseq import utils -from fairseq.data import data_utils - +from fairseq import metrics, utils from fairseq.criterions import register_criterion from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion +from fairseq.data import data_utils from espresso.tools import wer from espresso.tools.simple_greedy_decoder import SimpleGreedyDecoder @@ -188,20 +187,18 @@ def forward(self, model, sample, reduce=True): return loss, sample_size, logging_output @staticmethod -    def aggregate_logging_outputs(logging_outputs): +    def reduce_metrics(logging_outputs) -> None: """Aggregate logging outputs from data parallel training.""" -        agg_output = LabelSmoothedCrossEntropyCriterion.aggregate_logging_outputs(logging_outputs) +        LabelSmoothedCrossEntropyCriterion.reduce_metrics(logging_outputs) +        word_error = sum(log.get('word_error', 0) for log in logging_outputs) word_count = sum(log.get('word_count', 0) for log in logging_outputs) char_error = sum(log.get('char_error', 0) for log in logging_outputs) char_count = sum(log.get('char_count', 0) for log in logging_outputs) if word_count > 0:  # model.training == False -
agg_output['word_error'] = word_error - agg_output['word_count'] = word_count + metrics.log_scalar('wer', float(word_error) / word_count * 100, word_count, round=4) if char_count > 0: # model.training == False - agg_output['char_error'] = char_error - agg_output['char_count'] = char_count - return agg_output + metrics.log_scalar('cer', float(char_error) / char_count * 100, char_count, round=4) def set_num_updates(self, num_updates): self.num_updates = num_updates diff --git a/espresso/models/external_language_model.py b/espresso/models/external_language_model.py index 00777d0d9..d3661b5f3 100644 --- a/espresso/models/external_language_model.py +++ b/espresso/models/external_language_model.py @@ -246,7 +246,7 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): out_probs[batch_node_word_end_mask, :, self.subword_space_idx] = word_probs # take log of probs and clip it from below to avoid log(0) - out_logprobs = torch.max(out_probs, out_probs.new([self.zero])).log_() + out_logprobs = out_probs.clamp(min=self.zero).log_() # assign log-probs of emitting word to that of emitting subword out_logprobs[batch_space_mask, :, self.subword_eos_idx] = \ diff --git a/espresso/models/tensorized_lookahead_language_model.py b/espresso/models/tensorized_lookahead_language_model.py index 9a0370a57..a8d79af82 100644 --- a/espresso/models/tensorized_lookahead_language_model.py +++ b/espresso/models/tensorized_lookahead_language_model.py @@ -212,7 +212,7 @@ def forward(self, word_probs[batch_node_word_end_mask] # take log of probs and clip it from below to avoid log(0) - out_logprobs = torch.log(torch.max(out_probs, out_probs.new([self.zero]))) + out_logprobs = torch.log(out_probs.clamp(min=self.zero)) # assign log-probs of emitting word to that of emitting subword out_logprobs[batch_space_mask, :, self.subword_eos_idx] = \ diff --git a/espresso/speech_recognize.py b/espresso/speech_recognize.py index cc65f46bb..27decbb4c 100755 --- a/espresso/speech_recognize.py +++ b/espresso/speech_recognize.py @@ -8,6 +8,7 @@ Recognize pre-processed speech with a trained model. """ +import math import os import torch @@ -158,7 +159,8 @@ def main(args): hypo_sent = dictionary.tokens_to_sentence(hypo_str, bpe_symbol=args.remove_bpe) if not args.quiet: - print('H-{}\t{}\t{}'.format(utt_id, hypo_sent, hypo['score'])) + score = hypo['score'] / math.log(2) # convert to base 2 + print('H-{}\t{}\t{}'.format(utt_id, hypo_sent, score)) # Score and obtain attention only the top hypothesis if j == 0: @@ -177,6 +179,7 @@ def main(args): t.log({'wps': round(wps_meter.avg)}) num_sentences += sample['nsentences'] + print('| NOTE: hypothesis and token scores are output in base 2') print('| Recognized {} utterances ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format( num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg)) if args.print_alignment: diff --git a/espresso/speech_train.py b/espresso/speech_train.py index b64acea88..5b0a72da4 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -8,17 +8,18 @@ Train a new model on one or across multiple GPUs. 
""" -import collections import math import random import numpy as np import torch -from fairseq import checkpoint_utils, distributed_utils, options, progress_bar, tasks, utils +from fairseq import ( + checkpoint_utils, distributed_utils, metrics, options, progress_bar, tasks, utils +) from fairseq.data import iterators from fairseq.trainer import Trainer -from fairseq.meters import AverageMeter, StopwatchMeter +from fairseq.meters import StopwatchMeter def main(args, init_distributed=False): @@ -94,7 +95,7 @@ def main(args, init_distributed=False): else: valid_losses = [None] - # only use first validation wer to update the learning rate + # only use first validation loss to update the learning rate lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0]) # save checkpoint @@ -147,37 +148,24 @@ def train(args, trainer, task, epoch_itr): args, itr, epoch_itr.epoch, no_progress_bar='simple', ) - extra_meters = collections.defaultdict(lambda: AverageMeter()) valid_subsets = args.valid_subset.split(',') max_update = args.max_update or math.inf if hasattr(trainer.criterion, 'set_epoch'): trainer.criterion.set_epoch(epoch_itr.epoch) - for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch): + for samples in progress: if hasattr(trainer.criterion, 'set_num_updates'): trainer.criterion.set_num_updates(trainer.get_num_updates()) - log_output = trainer.train_step(samples) - if log_output is None: - continue - - # log mid-epoch stats - stats = get_training_stats(trainer) - for k, v in log_output.items(): - if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']: - continue # these are already logged above - if 'loss' in k or k == 'accuracy': - extra_meters[k].update(v, log_output['sample_size']) - else: - extra_meters[k].update(v) - stats[k] = extra_meters[k].avg - progress.log(stats, tag='train', step=stats['num_updates']) - - # ignore the first mini-batch in words-per-second and updates-per-second calculation - if i == 0: - trainer.get_meter('wps').reset() - trainer.get_meter('ups').reset() - - num_updates = trainer.get_num_updates() + with metrics.aggregate('train_inner'): + log_output = trainer.train_step(samples) + num_updates = trainer.get_num_updates() + if log_output is None: + continue + + # log mid-epoch stats + stats = get_training_stats('train_inner') + progress.log(stats, tag='train', step=num_updates) + if ( not args.disable_validation and args.save_interval_updates > 0 @@ -191,42 +179,18 @@ def train(args, trainer, task, epoch_itr): break # log end-of-epoch stats - stats = get_training_stats(trainer) - for k, meter in extra_meters.items(): - stats[k] = meter.avg - progress.print(stats, tag='train', step=stats['num_updates']) - - # reset training meters - for k in [ - 'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip', - ]: - meter = trainer.get_meter(k) - if meter is not None: - meter.reset() - - -def get_training_stats(trainer): - stats = collections.OrderedDict() - stats['loss'] = trainer.get_meter('train_loss') - if trainer.get_meter('train_nll_loss').count > 0: - nll_loss = trainer.get_meter('train_nll_loss') - stats['nll_loss'] = nll_loss - else: - nll_loss = trainer.get_meter('train_loss') - stats['ppl'] = utils.get_perplexity(nll_loss.avg) - stats['wps'] = trainer.get_meter('wps') - stats['ups'] = trainer.get_meter('ups') - stats['wpb'] = trainer.get_meter('wpb') - stats['bsz'] = trainer.get_meter('bsz') - stats['num_updates'] = trainer.get_num_updates() - stats['lr'] = trainer.get_lr() - stats['gnorm'] = 
trainer.get_meter('gnorm') - stats['clip'] = trainer.get_meter('clip') - stats['oom'] = trainer.get_meter('oom') - if trainer.get_meter('loss_scale') is not None: - stats['loss_scale'] = trainer.get_meter('loss_scale') - stats['wall'] = round(trainer.get_meter('wall').elapsed_time) - stats['train_wall'] = trainer.get_meter('train_wall') + stats = get_training_stats('train') + progress.print(stats, tag='train', step=num_updates) + + # reset epoch-level meters + metrics.reset_meters('train') + + +def get_training_stats(stats_key): + stats = metrics.get_smoothed_values(stats_key) + if 'nll_loss' in stats and 'ppl' not in stats: + stats['ppl'] = utils.get_perplexity(stats['nll_loss']) + stats['wall'] = round(metrics.get_meter('default', 'wall').elapsed_time, 0) return stats @@ -262,71 +226,30 @@ def validate(args, trainer, task, epoch_itr, subsets): ) # reset validation loss meters - for k in ['valid_loss', 'valid_nll_loss']: - meter = trainer.get_meter(k) - if meter is not None: - meter.reset() - extra_meters = collections.defaultdict(lambda: AverageMeter()) + metrics.reset_meters('valid') for sample in progress: - log_output = trainer.valid_step(sample) - - for k, v in log_output.items(): - if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', - 'sample_size', 'word_count', 'char_count']: - continue - if k == 'word_error': - extra_meters['wer'].update( - float(v) / log_output['word_count'] * 100, - log_output['word_count']) - elif k == 'char_error': - extra_meters['cer'].update( - float(v) / log_output['char_count'] * 100, - log_output['char_count']) - else: - extra_meters[k].update(v) + trainer.valid_step(sample) # log validation stats - stats = get_valid_stats(trainer, args, extra_meters) - for k, meter in extra_meters.items(): - stats[k] = meter.avg + stats = get_valid_stats(args, trainer) progress.print(stats, tag=subset, step=trainer.get_num_updates()) - valid_losses.append( - stats[args.best_checkpoint_metric].avg - if args.best_checkpoint_metric == 'loss' - else stats[args.best_checkpoint_metric] - ) + valid_losses.append(stats[args.best_checkpoint_metric]) return valid_losses -def get_valid_stats(trainer, args, extra_meters=None): - stats = collections.OrderedDict() - stats['loss'] = trainer.get_meter('valid_loss') - if trainer.get_meter('valid_nll_loss').count > 0: - nll_loss = trainer.get_meter('valid_nll_loss') - stats['nll_loss'] = nll_loss - else: - nll_loss = stats['loss'] - stats['ppl'] = utils.get_perplexity(nll_loss.avg) +def get_valid_stats(args, trainer): + stats = metrics.get_smoothed_values('valid') + if 'nll_loss' in stats and 'ppl' not in stats: + stats['ppl'] = utils.get_perplexity(stats['nll_loss']) stats['num_updates'] = trainer.get_num_updates() if hasattr(checkpoint_utils.save_checkpoint, 'best'): key = 'best_{0}'.format(args.best_checkpoint_metric) best_function = max if args.maximize_best_checkpoint_metric else min - - current_metric = None - if args.best_checkpoint_metric == 'loss': - current_metric = stats['loss'].avg - elif args.best_checkpoint_metric in extra_meters: - current_metric = extra_meters[args.best_checkpoint_metric].avg - elif args.best_checkpoint_metric in stats: - current_metric = stats[args.best_checkpoint_metric] - else: - raise ValueError("best_checkpoint_metric not found in logs") - stats[key] = best_function( checkpoint_utils.save_checkpoint.best, - current_metric, + stats[args.best_checkpoint_metric], ) return stats From 458828d4e9784619430281cca4eefed324c694b9 Mon Sep 17 00:00:00 2001 From: freewym Date: Fri, 17 Jan 2020 18:05:18 -0500 
Subject: [PATCH 061/119] code adaptation/changes according to the commits from Jan 16 to Jan 17, 2020; move decode log files to decode result dirs --- espresso/criterions/cross_entropy_with_wer.py | 8 +- .../label_smoothed_cross_entropy_with_wer.py | 9 +- espresso/data/scp_text_dataset.py | 6 +- espresso/data/speech_dataset.py | 7 +- espresso/models/speech_fconv.py | 7 +- espresso/models/speech_lstm.py | 9 +- espresso/models/speech_transformer.py | 8 +- espresso/speech_recognize.py | 77 +++++++++++------ espresso/speech_train.py | 85 +++++++++++-------- espresso/tasks/language_modeling_for_asr.py | 10 ++- espresso/tasks/speech_recognition.py | 17 ++-- espresso/tools/compute_wer.py | 15 +++- espresso/tools/text2vocabulary.py | 32 ++++--- espresso/tools/wer.py | 9 +- examples/asr_librispeech/run.sh | 9 +- examples/asr_swbd/run.sh | 7 +- examples/asr_wsj/run.sh | 9 +- tests/espresso/test_speech_utils.py | 8 +- 18 files changed, 212 insertions(+), 120 deletions(-) diff --git a/espresso/criterions/cross_entropy_with_wer.py b/espresso/criterions/cross_entropy_with_wer.py index 7c48799cc..3946040ec 100644 --- a/espresso/criterions/cross_entropy_with_wer.py +++ b/espresso/criterions/cross_entropy_with_wer.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import logging import numpy as np import torch.nn.functional as F @@ -16,6 +17,9 @@ from espresso.tools.simple_greedy_decoder import SimpleGreedyDecoder +logger = logging.getLogger(__name__) + + @register_criterion('cross_entropy_with_wer') class CrossEntropyWithWERCriterion(CrossEntropyCriterion): @@ -70,8 +74,8 @@ def forward(self, model, sample, reduce=True): dictionary.string(pred.data[i][:length]), use_unk_sym=True, bpe_symbol=self.args.remove_bpe, ) - print('| sample REF: ' + ref_one) - print('| sample PRD: ' + pred_one) + logger.info('sample REF: ' + ref_one) + logger.info('sample PRD: ' + pred_one) else: tokens, lprobs, _ = self.decoder_for_validation.decode([model], sample) pred = tokens[:, 1:].data.cpu() # bsz x len diff --git a/espresso/criterions/label_smoothed_cross_entropy_with_wer.py b/espresso/criterions/label_smoothed_cross_entropy_with_wer.py index 713591c7c..3f3ea1190 100644 --- a/espresso/criterions/label_smoothed_cross_entropy_with_wer.py +++ b/espresso/criterions/label_smoothed_cross_entropy_with_wer.py @@ -3,7 +3,9 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
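(Illustrative aside, not part of the patch: this commit replaces bare print() calls with per-module loggers, while the entry scripts install the handler once with logging.basicConfig, as shown later in this patch for speech_train.py and speech_recognize.py. A minimal sketch of the pattern, using a hypothetical module name:)
import logging
import sys

# done once, in the entry-point script
logging.basicConfig(
    format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
    stream=sys.stdout,
)

# done in each library module; records propagate to the root handler above
logger = logging.getLogger('espresso.example_module')
logger.info('dictionary: %d types', 52)
logger.warning('Unsupported pattern: "%s". Ignoring it.', '#!foo')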
+import logging import numpy as np + import torch from fairseq import metrics, utils @@ -15,6 +17,9 @@ from espresso.tools.simple_greedy_decoder import SimpleGreedyDecoder +logger = logging.getLogger(__name__) + + def temporal_label_smoothing_prob_mask( lprobs: torch.Tensor, # R[Batch, SeqLength, Vocab] target: torch.Tensor, # Z[Batch, SeqLength] @@ -143,8 +148,8 @@ def forward(self, model, sample, reduce=True): dictionary.string(pred.data[i][:length]), use_unk_sym=True, bpe_symbol=self.args.remove_bpe, ) - print('| sample REF: ' + ref_one) - print('| sample PRD: ' + pred_one) + logger.info('sample REF: ' + ref_one) + logger.info('sample PRD: ' + pred_one) else: tokens, lprobs, _ = self.decoder_for_validation.decode([model], sample) pred = tokens[:, 1:].data.cpu() # bsz x len diff --git a/espresso/data/scp_text_dataset.py b/espresso/data/scp_text_dataset.py index eb068a42c..436995d6c 100644 --- a/espresso/data/scp_text_dataset.py +++ b/espresso/data/scp_text_dataset.py @@ -38,8 +38,7 @@ def read_scp(self, path): try: feat = kaldi_io.read_mat(filename) except Exception: - print('Failed to read feature matrix {}.'.format(filename)) - raise + raise Exception('failed to read feature matrix {}.'.format(filename)) assert feat is not None and isinstance(feat, np.ndarray) self.sizes.append(feat.shape[0]) self.sizes = np.array(self.sizes, dtype=np.int32) @@ -123,12 +122,11 @@ def __getitem__(self, i): i, self.start_pos_for_next_cache, ) except ValueError: - print( + raise ValueError( 'index {} not found in self.ordered_indices. Set ' 'self.ordered_prefetch to False, and/or call self.prefetch() ' 'with the full list of indices, and then try again.'.format(i) ) - raise pos_end = min( pos_start + self.cache_size, len(self.ordered_indices), ) diff --git a/espresso/data/speech_dataset.py b/espresso/data/speech_dataset.py index b517e451b..dcc76a35f 100644 --- a/espresso/data/speech_dataset.py +++ b/espresso/data/speech_dataset.py @@ -140,9 +140,10 @@ def _match_src_tgt(self): try: tgt_indices = list(map(self.tgt.utt_ids.index, self.src.utt_ids)) except ValueError: - print('Unable to find some utt_id(s) in tgt. which is unlikely to \ - happen. Something must be wrong.') - raise + raise ValueError( + 'Unable to find some utt_id(s) in tgt. which is unlikely to happen. ' + 'Something must be wrong.' + ) self.tgt.filter_and_reorder(tgt_indices) self.tgt_sizes = np.array(self.tgt.sizes) assert self.src.utt_ids == self.tgt.utt_ids diff --git a/espresso/models/speech_fconv.py b/espresso/models/speech_fconv.py index b3bd59891..3f119317f 100644 --- a/espresso/models/speech_fconv.py +++ b/espresso/models/speech_fconv.py @@ -3,7 +3,9 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import logging import math + import torch import torch.nn as nn import torch.nn.functional as F @@ -27,6 +29,9 @@ import espresso.tools.utils as speech_utils +logger = logging.getLogger(__name__) + + @register_model('speech_fconv') class SpeechFConvModel(FConvModel): """ @@ -96,7 +101,7 @@ def eval_str_nested_list_or_tuple(x, type=int): out_channels = eval_str_nested_list_or_tuple(args.encoder_conv_channels, type=int) kernel_sizes = eval_str_nested_list_or_tuple(args.encoder_conv_kernel_sizes, type=int) strides = eval_str_nested_list_or_tuple(args.encoder_conv_strides, type=int) - print('| input feature dimension: {}, channels: {}'.format(task.feat_dim, task.feat_in_channels)) + logger.info('input feature dimension: {}, channels: {}'.format(task.feat_dim, task.feat_in_channels)) assert task.feat_dim % task.feat_in_channels == 0 conv_layers = ConvBNReLU( out_channels, kernel_sizes, strides, in_channels=task.feat_in_channels, diff --git a/espresso/models/speech_lstm.py b/espresso/models/speech_lstm.py index b76cb7c3f..0fbb8a69c 100644 --- a/espresso/models/speech_lstm.py +++ b/espresso/models/speech_lstm.py @@ -3,6 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import logging + import torch import torch.nn as nn import torch.nn.functional as F @@ -31,6 +33,9 @@ import espresso.tools.utils as speech_utils +logger = logging.getLogger(__name__) + + @register_model('speech_lstm') class SpeechLSTMModel(FairseqEncoderDecoderModel): def __init__(self, encoder, decoder, pretrained_lm=None): @@ -171,7 +176,7 @@ def eval_str_nested_list_or_tuple(x, type=int): out_channels = eval_str_nested_list_or_tuple(args.encoder_conv_channels, type=int) kernel_sizes = eval_str_nested_list_or_tuple(args.encoder_conv_kernel_sizes, type=int) strides = eval_str_nested_list_or_tuple(args.encoder_conv_strides, type=int) - print('| input feature dimension: {}, channels: {}'.format(task.feat_dim, task.feat_in_channels)) + logger.info('input feature dimension: {}, channels: {}'.format(task.feat_dim, task.feat_in_channels)) assert task.feat_dim % task.feat_in_channels == 0 conv_layers = ConvBNReLU( out_channels, kernel_sizes, strides, in_channels=task.feat_in_channels, @@ -226,7 +231,7 @@ def eval_str_nested_list_or_tuple(x, type=int): ) pretrained_lm = None if args.pretrained_lm_checkpoint: - print('| loading pretrained LM from {}'.format(args.pretrained_lm_checkpoint)) + logger.info('loading pretrained LM from {}'.format(args.pretrained_lm_checkpoint)) pretrained_lm = checkpoint_utils.load_model_ensemble( args.pretrained_lm_checkpoint, task=task)[0][0] pretrained_lm.make_generation_fast_() diff --git a/espresso/models/speech_transformer.py b/espresso/models/speech_transformer.py index bba452fc8..75419585d 100644 --- a/espresso/models/speech_transformer.py +++ b/espresso/models/speech_transformer.py @@ -3,6 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import logging + import torch import torch.nn as nn import torch.nn.functional as F @@ -25,10 +27,14 @@ from espresso.models.speech_lstm import ConvBNReLU import espresso.tools.utils as speech_utils + DEFAULT_MAX_SOURCE_POSITIONS = 9999 DEFAULT_MAX_TARGET_POSITIONS = 999 +logger = logging.getLogger(__name__) + + @register_model('speech_transformer') class SpeechTransformerModel(TransformerModel): """ @@ -112,7 +118,7 @@ def eval_str_nested_list_or_tuple(x, type=int): out_channels = eval_str_nested_list_or_tuple(args.encoder_conv_channels, type=int) kernel_sizes = eval_str_nested_list_or_tuple(args.encoder_conv_kernel_sizes, type=int) strides = eval_str_nested_list_or_tuple(args.encoder_conv_strides, type=int) - print('| input feature dimension: {}, channels: {}'.format(task.feat_dim, task.feat_in_channels)) + logger.info('input feature dimension: {}, channels: {}'.format(task.feat_dim, task.feat_in_channels)) assert task.feat_dim % task.feat_in_channels == 0 conv_layers = ConvBNReLU( out_channels, kernel_sizes, strides, in_channels=task.feat_in_channels, diff --git a/espresso/speech_recognize.py b/espresso/speech_recognize.py index 27decbb4c..48fdac0c7 100755 --- a/espresso/speech_recognize.py +++ b/espresso/speech_recognize.py @@ -8,8 +8,10 @@ Recognize pre-processed speech with a trained model. """ +import logging import math import os +import sys import torch @@ -28,11 +30,32 @@ def main(args): assert not args.sampling or args.nbest == args.beam, \ '--sampling requires --nbest to be equal to --beam' + if args.results_path is not None: + os.makedirs(args.results_path, exist_ok=True) + output_path = os.path.join(args.results_path, 'decode.log') + with open(output_path, 'w', buffering=1) as h: + return _main(args, h) + return _main(args, sys.stdout) + + +def _main(args, output_file): + logging.basicConfig( + format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + level=logging.INFO, + stream=output_file, + ) + logger = logging.getLogger('espresso.speech_recognize') + if output_file is not sys.stdout: # also print to stdout + logger.addHandler(logging.StreamHandler(sys.stdout)) + + print_options_meaning_changes(args, logger) + utils.import_user_module(args) if args.max_tokens is None and args.max_sentences is None: args.max_tokens = 12000 - print(args) + logger.info(args) use_cuda = torch.cuda.is_available() and not args.cpu @@ -44,7 +67,7 @@ def main(args): dictionary = task.target_dictionary # Load ensemble - print('| loading model(s) from {}'.format(args.path)) + logger.info('loading model(s) from {}'.format(args.path)) models, _model_args = checkpoint_utils.load_model_ensemble( args.path.split(':'), arg_overrides=eval(args.model_overrides), @@ -61,19 +84,19 @@ def main(args): open_vocab=not args.disable_open_vocab, ) del models[i] - print('| LM fusion with Multi-level LM') + logger.info('LM fusion with Multi-level LM') else: models[i] = TensorizedLookaheadLanguageModel( m, dictionary, oov_penalty=args.oov_penalty, open_vocab=not args.disable_open_vocab, ) - print('| LM fusion with Look-ahead Word LM') + logger.info('LM fusion with Look-ahead Word LM') # assume subword LM comes after E2E models elif i == len(models) - 1 and isinstance(m, FairseqLanguageModel): - print('| LM fusion with Subword LM') + logger.info('LM fusion with Subword LM') if args.lm_weight != 0.0: - print('| using LM fusion with lm-weight={:.2f}'.format(args.lm_weight)) + logger.info('using LM fusion with lm-weight={:.2f}'.format(args.lm_weight)) # Optimize ensemble for 
generation for model in models: @@ -105,8 +128,9 @@ def main(args): # Initialize generator if args.match_source_len: - print('| The option match_source_len is not applicable to ' - 'speech recognition. Ignoring it.') + logger.warning( + 'The option match_source_len is not applicable to speech recognition. Ignoring it.' + ) gen_timer = StopwatchMeter() generator = task.build_generator(args) @@ -150,7 +174,7 @@ def main(args): target_sent = dictionary.tokens_to_sentence( target_str, use_unk_sym=False, bpe_symbol=args.remove_bpe, ) - print('T-{}\t{}'.format(utt_id, target_sent)) + print('T-{}\t{}'.format(utt_id, target_sent), file=output_file) # Process top predictions for j, hypo in enumerate(hypos[i][:args.nbest]): @@ -160,7 +184,7 @@ def main(args): if not args.quiet: score = hypo['score'] / math.log(2) # convert to base 2 - print('H-{}\t{}\t{}'.format(utt_id, hypo_sent, score)) + print('H-{}\t{}\t{}'.format(utt_id, hypo_sent, score), file=output_file) # Score and obtain attention only the top hypothesis if j == 0: @@ -179,64 +203,62 @@ def main(args): t.log({'wps': round(wps_meter.avg)}) num_sentences += sample['nsentences'] - print('| NOTE: hypothesis and token scores are output in base 2') - print('| Recognized {} utterances ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format( + logger.info('NOTE: hypothesis and token scores are output in base 2') + logger.info('Recognized {} utterances ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format( num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg)) if args.print_alignment: - print('| Saved attention plots in ' + save_dir) + logger.info('Saved attention plots in ' + save_dir) if has_target: assert args.test_text_files is not None scorer.add_ordered_utt_list(*args.test_text_files) - os.makedirs(args.results_path, exist_ok=True) - fn = 'decoded_char_results.txt' with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f: f.write(scorer.print_char_results()) - print('| Decoded char results saved as ' + f.name) + logger.info('Decoded char results saved as ' + f.name) fn = 'decoded_results.txt' with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f: f.write(scorer.print_results()) - print('| Decoded results saved as ' + f.name) + logger.info('Decoded results saved as ' + f.name) if has_target: - header = ' Recognize {} with beam={}: '.format(args.gen_subset, args.beam) + header = 'Recognize {} with beam={}: '.format(args.gen_subset, args.beam) fn = 'wer' with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f: res = 'WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%'.format( *(scorer.wer())) - print('|' + header + res) + logger.info(header + res) f.write(res + '\n') - print('| WER saved in ' + f.name) + logger.info('WER saved in ' + f.name) fn = 'cer' with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f: res = 'CER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%'.format( *(scorer.cer())) - print('|' + ' ' * len(header) + res) + logger.info(' ' * len(header) + res) f.write(res + '\n') - print('| CER saved in ' + f.name) + logger.info('CER saved in ' + f.name) fn = 'aligned_results.txt' with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f: f.write(scorer.print_aligned_results()) - print('| Aligned results saved as ' + f.name) + logger.info('Aligned results saved as ' + f.name) return scorer -def print_options_meaning_changes(args): +def print_options_meaning_changes(args, logger): 
"""Options that have different meanings than those in the translation task are explained here. """ - print('| --max-tokens is the maximum number of input frames in a batch') + logger.info('--max-tokens is the maximum number of input frames in a batch') if args.print_alignment: - print('| --print-alignment has been set to plot attentions') + logger.info('--print-alignment has been set to plot attentions') def cli_main(): - parser = options.get_generation_parser(default_task='speech_recognition') + parser = options.get_generation_parser(default_task='speech_recognition_espresso') parser.add_argument('--eos-factor', default=None, type=float, metavar='F', help='only consider emitting EOS if its score is no less ' 'than the specified factor of the best candidate score') @@ -253,7 +275,6 @@ def cli_main(): 'pretrained external LM') args = options.parse_args_and_arch(parser) assert args.results_path is not None, 'please specify --results-path' - print_options_meaning_changes(args) main(args) diff --git a/espresso/speech_train.py b/espresso/speech_train.py index 5b0a72da4..332ad8b2e 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -8,8 +8,10 @@ Train a new model on one or across multiple GPUs. """ +import logging import math import random +import sys import numpy as np import torch @@ -22,6 +24,15 @@ from fairseq.meters import StopwatchMeter +logging.basicConfig( + format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + level=logging.INFO, + stream=sys.stdout, +) +logger = logging.getLogger('espresso.speech_train') + + def main(args, init_distributed=False): utils.import_user_module(args) @@ -40,7 +51,7 @@ def main(args, init_distributed=False): checkpoint_utils.verify_checkpoint_directory(args.save_dir) # Print args - print(args) + logger.info(args) # Setup task, e.g., translation, language modeling, etc. task = tasks.setup_task(args) @@ -52,17 +63,17 @@ def main(args, init_distributed=False): # Build model and criterion model = task.build_model(args) criterion = task.build_criterion(args) - print(model) - print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__)) - print('| num. model params: {} (num. trained: {})'.format( + logger.info(model) + logger.info('model {}, criterion {}'.format(args.arch, criterion.__class__.__name__)) + logger.info('num. model params: {} (num. 
trained: {})'.format( sum(p.numel() for p in model.parameters()), sum(p.numel() for p in model.parameters() if p.requires_grad), )) # Build trainer trainer = Trainer(args, task, model, criterion) - print('| training on {} GPUs'.format(args.distributed_world_size)) - print('| max input frames per GPU = {} and max sentences per GPU = {}'.format( + logger.info('training on {} GPUs'.format(args.distributed_world_size)) + logger.info('max input frames per GPU = {} and max sentences per GPU = {}'.format( args.max_tokens, args.max_sentences, )) @@ -104,14 +115,14 @@ def main(args, init_distributed=False): # early stop if should_stop_early(args, valid_losses[0]): - print('| Early stop since valid performance hasn\'t improved for last {} runs'.format(args.patience)) + logger.info('early stop since valid performance hasn\'t improved for last {} runs'.format(args.patience)) break reload_dataset = len(args.train_feat_files) > 1 # sharded data: get train iterator for next epoch epoch_itr = trainer.get_train_iterator(epoch_itr.epoch, load_dataset=reload_dataset) train_meter.stop() - print('| done training in {:.1f} seconds'.format(train_meter.sum)) + logger.info('done training in {:.1f} seconds'.format(train_meter.sum)) def should_stop_early(args, valid_loss): @@ -148,46 +159,48 @@ def train(args, trainer, task, epoch_itr): args, itr, epoch_itr.epoch, no_progress_bar='simple', ) + # task specific setup per epoch + task.begin_epoch(epoch_itr.epoch, trainer.get_model()) + valid_subsets = args.valid_subset.split(',') max_update = args.max_update or math.inf if hasattr(trainer.criterion, 'set_epoch'): trainer.criterion.set_epoch(epoch_itr.epoch) - for samples in progress: - if hasattr(trainer.criterion, 'set_num_updates'): - trainer.criterion.set_num_updates(trainer.get_num_updates()) + with metrics.aggregate() as agg: + for samples in progress: + if hasattr(trainer.criterion, 'set_num_updates'): + trainer.criterion.set_num_updates(trainer.get_num_updates()) - with metrics.aggregate('train_inner'): log_output = trainer.train_step(samples) num_updates = trainer.get_num_updates() if log_output is None: continue # log mid-epoch stats - stats = get_training_stats('train_inner') + stats = get_training_stats(agg.get_smoothed_values()) progress.log(stats, tag='train', step=num_updates) - if ( - not args.disable_validation - and args.save_interval_updates > 0 - and num_updates % args.save_interval_updates == 0 - and num_updates > 0 - ): - valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets) - checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) + if ( + not args.disable_validation + and args.save_interval_updates > 0 + and num_updates % args.save_interval_updates == 0 + and num_updates > 0 + ): + valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets) + checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) - if num_updates >= max_update: - break + if num_updates >= max_update: + break # log end-of-epoch stats - stats = get_training_stats('train') + stats = get_training_stats(agg.get_smoothed_values()) progress.print(stats, tag='train', step=num_updates) # reset epoch-level meters metrics.reset_meters('train') -def get_training_stats(stats_key): - stats = metrics.get_smoothed_values(stats_key) +def get_training_stats(stats): if 'nll_loss' in stats and 'ppl' not in stats: stats['ppl'] = utils.get_perplexity(stats['nll_loss']) stats['wall'] = round(metrics.get_meter('default', 'wall').elapsed_time, 0) @@ -225,22 +238,22 @@ def validate(args, 
trainer, task, epoch_itr, subsets): no_progress_bar='simple' ) - # reset validation loss meters + # reset validation meters metrics.reset_meters('valid') - for sample in progress: - trainer.valid_step(sample) + with metrics.aggregate() as agg: + for sample in progress: + trainer.valid_step(sample) # log validation stats - stats = get_valid_stats(args, trainer) + stats = get_valid_stats(args, trainer, agg.get_smoothed_values()) progress.print(stats, tag=subset, step=trainer.get_num_updates()) valid_losses.append(stats[args.best_checkpoint_metric]) return valid_losses -def get_valid_stats(args, trainer): - stats = metrics.get_smoothed_values('valid') +def get_valid_stats(args, trainer, stats): if 'nll_loss' in stats and 'ppl' not in stats: stats['ppl'] = utils.get_perplexity(stats['nll_loss']) stats['num_updates'] = trainer.get_num_updates() @@ -265,16 +278,16 @@ def print_options_meaning_changes(args): """Options that have different meanings than those in the translation task are explained here. """ - print('| --max-tokens is the maximum number of input frames in a batch') + logger.info('--max-tokens is the maximum number of input frames in a batch') -def cli_main(): +def cli_main(modify_parser=None): parser = options.get_training_parser() parser.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None, help='remove BPE tokens before scoring ' '(can be set to sentencepiece). Being used for monitoring ' 'and validation') - args = options.parse_args_and_arch(parser) + args = options.parse_args_and_arch(parser, modify_parser=modify_parser) print_options_meaning_changes(args) if args.distributed_init_method is None: @@ -299,7 +312,7 @@ def cli_main(): args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port) args.distributed_rank = None # set based on device id if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d': - print('| NOTE: you may get faster training with: --ddp-backend=no_c10d') + logger.info('NOTE: you may get faster training with: --ddp-backend=no_c10d') torch.multiprocessing.spawn( fn=distributed_main, args=(args, ), diff --git a/espresso/tasks/language_modeling_for_asr.py b/espresso/tasks/language_modeling_for_asr.py index 444ed66af..433f224b0 100644 --- a/espresso/tasks/language_modeling_for_asr.py +++ b/espresso/tasks/language_modeling_for_asr.py @@ -3,10 +3,11 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-import torch - +import logging import os +import torch + from fairseq import tokenizer from fairseq.data import TruncatedDictionary from fairseq.tasks import register_task @@ -15,6 +16,9 @@ from espresso.data import AsrDictionary +logger = logging.getLogger(__name__) + + @register_task("language_modeling_for_asr") class LanguageModelingForASRTask(LanguageModelingTask): """ @@ -106,7 +110,7 @@ def setup_task(cls, args, **kwargs): dict_path = os.path.join(paths[0], "dict.txt") if args.dict is None \ else args.dict dictionary = AsrDictionary.load(dict_path) - print("| dictionary: {} types".format(len(dictionary))) + logger.info("dictionary: {} types".format(len(dictionary))) output_dictionary = dictionary if args.output_dictionary_size >= 0: output_dictionary = TruncatedDictionary( diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index 6bd9d360b..e5cc80cea 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import logging import os import torch @@ -20,6 +21,9 @@ ) +logger = logging.getLogger(__name__) + + @register_task('speech_recognition_espresso') class SpeechRecognitionEspressoTask(FairseqTask): """ @@ -134,10 +138,10 @@ def setup_task(cls, args, **kwargs): if args.dict is None and args.train_text_files is not None else args.dict assert dict_path is not None, 'Please specify --dict' dictionary = cls.load_dictionary(dict_path, non_lang_syms=args.non_lang_syms) - print('| dictionary: {} types'.format(len(dictionary))) + logger.info('dictionary: {} types'.format(len(dictionary))) if args.word_dict is not None: word_dict = cls.load_dictionary(args.word_dict) - print('| word dictionary: {} types'.format(len(word_dict))) + logger.info('word dictionary: {} types'.format(len(word_dict))) return cls(args, dictionary, word_dict) else: @@ -178,10 +182,10 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs): assert ScpCachedDataset.exists(feat), feat + ' does not exists' assert text is None or AsrTextDataset.exists(text), text + ' does not exists' src_datasets.append(ScpCachedDataset(feat, ordered_prefetch=True)) - print('| {} {} examples'.format(feat, len(src_datasets[-1]))) + logger.info('{} {} examples'.format(feat, len(src_datasets[-1]))) if text is not None: tgt_datasets.append(AsrTextDataset(text, self.dictionary)) - print('| {} {} examples'.format(text, len(tgt_datasets[-1]))) + logger.info('{} {} examples'.format(text, len(tgt_datasets[-1]))) if not combine: break @@ -225,8 +229,9 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs): def build_generator(self, args): if args.score_reference: args.score_reference = False - print('| --score-reference is not applicable to speech recognition,' - ' ignoring it.') + logger.warning( + '--score-reference is not applicable to speech recognition, ignoring it.' + ) from fairseq.sequence_generator import SequenceGenerator # Choose search strategy. Defaults to Beam Search. diff --git a/espresso/tools/compute_wer.py b/espresso/tools/compute_wer.py index 4b6fae77c..8555e0995 100755 --- a/espresso/tools/compute_wer.py +++ b/espresso/tools/compute_wer.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. 
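(Illustrative aside, not part of the patch: compute_wer.py, whose changes follow, reports WER as a word-level edit distance between reference and hypothesis normalized by the reference length; the real implementation uses edit_distance from espresso.tools.utils, which also separates substitutions, insertions, and deletions. A self-contained sketch of the basic computation:)
def word_error_rate(ref, hyp):
    # Levenshtein distance over words, normalized by the reference length.
    ref, hyp = ref.split(), hyp.split()
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i
    for j in range(len(hyp) + 1):
        d[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            sub = d[i - 1][j - 1] + (ref[i - 1] != hyp[j - 1])
            d[i][j] = min(sub, d[i - 1][j] + 1, d[i][j - 1] + 1)
    return 100.0 * d[len(ref)][len(hyp)] / len(ref)

print(round(word_error_rate('the cat sat on the mat', 'the cat sat on mat'), 2))  # 16.67 (1 deletion / 6 ref words)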
import argparse +import logging import re import sys from collections import Counter @@ -12,6 +13,15 @@ from espresso.tools.utils import edit_distance +logging.basicConfig( + format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + level=logging.INFO, + stream=sys.stderr, +) +logger = logging.getLogger('espresso.tools.compute_wer') + + def get_parser(): parser = argparse.ArgumentParser( description='Compute WER from text') @@ -53,10 +63,7 @@ def main(args): assert m is not None word_filters.append([m.group(1), m.group(2)]) else: - print( - 'Unsupported pattern: "{}", ignored'.format(line), - file=sys.stderr, - ) + logger.warning('Unsupported pattern: "{}". Ignoring it.'.format(line)) refs = {} with open(args.ref_text, 'r', encoding='utf-8') as f: diff --git a/espresso/tools/text2vocabulary.py b/espresso/tools/text2vocabulary.py index 50681b967..047d43957 100755 --- a/espresso/tools/text2vocabulary.py +++ b/espresso/tools/text2vocabulary.py @@ -5,11 +5,21 @@ # LICENSE file in the root directory of this source tree. import argparse +import logging import os import sys from collections import Counter +logging.basicConfig( + format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + level=logging.INFO, + stream=sys.stderr, +) +logger = logging.getLogger('espresso.tools.text2vocabulary') + + def get_parser(): parser = argparse.ArgumentParser( description='Create a vocabulary from text files') @@ -65,7 +75,7 @@ def main(args): most_common = most_common[:cutoff_point] vocab_set = set(list(zip(*most_common))[0]) else: - print('using the provided vocabulary:', file=sys.stderr) + logger.info('using the provided vocabulary:') with open(args.vocab, 'r', encoding='utf-8') as f: vocab_set = set([line.rstrip().split()[0] for line in f]) most_common = [] @@ -78,11 +88,11 @@ def main(args): print('{} {:d}'.format(w, c)) oov_rate = 1. - float(invocab_count) / total_count - print('training set:', file=sys.stderr) - print(' total #tokens={:d}'.format(total_count), file=sys.stderr) - print(' OOV rate={:.2f}%'.format(oov_rate * 100), file=sys.stderr) + logger.info('training set:') + logger.info(' total #tokens={:d}'.format(total_count)) + logger.info(' OOV rate={:.2f}%'.format(oov_rate * 100)) if args.vocab is None: - print(' cutoff frequency={:d}'.format(cutoff_freq), file=sys.stderr) + logger.info(' cutoff frequency={:d}'.format(cutoff_freq)) if args.valid_text is not None: total_count = 0 @@ -94,9 +104,9 @@ def main(args): total_count += len(tokens) invocab_count += len([tok for tok in tokens if tok in vocab_set]) oov_rate = 1. - float(invocab_count) / total_count - print('validation set:', file=sys.stderr) - print(' total #tokens={:d}'.format(total_count), file=sys.stderr) - print(' OOV rate={:.2f}%'.format(oov_rate * 100), file=sys.stderr) + logger.info('validation set:') + logger.info(' total #tokens={:d}'.format(total_count)) + logger.info(' OOV rate={:.2f}%'.format(oov_rate * 100)) if args.test_text is not None: for k, path in enumerate(args.test_text.split(os.pathsep)): @@ -109,9 +119,9 @@ def main(args): total_count += len(tokens) invocab_count += len([tok for tok in tokens if tok in vocab_set]) oov_rate = 1. 
- float(invocab_count) / total_count - print('test set{}:'.format(k) if k > 0 else 'test set:', file=sys.stderr) - print(' total #tokens={:d}'.format(total_count), file=sys.stderr) - print(' OOV rate={:.2f}%'.format(oov_rate * 100), file=sys.stderr) + logger.info('test set{}:'.format(k) if k > 0 else 'test set:') + logger.info(' total #tokens={:d}'.format(total_count)) + logger.info(' OOV rate={:.2f}%'.format(oov_rate * 100)) if __name__ == '__main__': diff --git a/espresso/tools/wer.py b/espresso/tools/wer.py index f3c8c25af..f8b7618b5 100644 --- a/espresso/tools/wer.py +++ b/espresso/tools/wer.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import logging import re import sys @@ -11,6 +12,9 @@ import espresso.tools.utils as speech_utils +logger = logging.getLogger(__name__) + + class Scorer(object): def __init__(self, dictionary, wer_output_filter=None): self.dictionary = dictionary @@ -42,10 +46,7 @@ def parse_wer_output_filter(self, wer_output_filter): assert m is not None self.word_filters.append([m.group(1), m.group(2)]) else: - print( - 'Unsupported pattern: "{}", ignored'.format(line), - file=sys.stderr, - ) + logger.warning('Unsupported pattern: "{}". Ignoring it'.format(line)) def add_prediction(self, utt_id, pred, bpe_symbol=None): if not isinstance(utt_id, str): diff --git a/examples/asr_librispeech/run.sh b/examples/asr_librispeech/run.sh index 084a418ab..ce02fd404 100755 --- a/examples/asr_librispeech/run.sh +++ b/examples/asr_librispeech/run.sh @@ -236,19 +236,20 @@ if [ ${stage} -le 8 ]; then for dataset in $test_set; do feat=${dumpdir}/$dataset/delta${do_delta}/feats.scp text=data/$dataset/token_text + decode_dir=$dir/decode_$dataset${decode_affix:+_${decode_affix}} CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py \ --task speech_recognition_espresso --user-dir espresso --max-tokens 15000 --max-sentences 24 \ --num-shards 1 --shard-id 0 --test-feat-files $feat --test-text-files $text \ --dict $dict --remove-bpe sentencepiece \ --max-source-positions 9999 --max-target-positions 999 \ --path $path --beam 60 --max-len-a 0.08 --max-len-b 0 --lenpen 1.0 \ - --results-path $dir/decode_$dataset${decode_affix:+_${decode_affix}} $opts \ - 2>&1 | tee $dir/logs/decode_$dataset${decode_affix:+_${decode_affix}}.log + --results-path $decode_dir $opts + echo "log saved in ${decode_dir}/decode.log" if $kaldi_scoring; then echo "verify WER by scoring with Kaldi..." 
- local/score.sh data/$dataset $dir/decode_$dataset${decode_affix:+_${decode_affix}} - cat $dir/decode_$dataset${decode_affix:+_${decode_affix}}/scoring_kaldi/wer + local/score.sh data/$dataset $decode_dir + cat ${decode_dir}/scoring_kaldi/wer fi done fi diff --git a/examples/asr_swbd/run.sh b/examples/asr_swbd/run.sh index 4565d465e..10b7e8371 100755 --- a/examples/asr_swbd/run.sh +++ b/examples/asr_swbd/run.sh @@ -277,18 +277,19 @@ if [ $stage -le 7 ]; then fi [ -f local/wer_output_filter ] && opts="$opts --wer-output-filter local/wer_output_filter" for dataset in $test_set; do + feat=${dumpdir}/$dataset/delta${do_delta}/feats.scp decode_dir=$dir/decode_${dataset}${decode_affix:+_${decode_affix}} # only score train_dev with built-in scorer text_opt= && [ "$dataset" == "train_dev" ] && text_opt="--test-text-files data/$dataset/token_text" CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py \ --task speech_recognition_espresso --user-dir espresso --max-tokens 24000 --max-sentences 48 \ - --num-shards 1 --shard-id 0 --test-feat-files ${dumpdir}/$dataset/delta${do_delta}/feats.scp $text_opt \ + --num-shards 1 --shard-id 0 --test-feat-files $feat $text_opt \ --dict $dict --remove-bpe sentencepiece --non-lang-syms $nlsyms \ --max-source-positions 9999 --max-target-positions 999 \ --path $path --beam 35 --max-len-a 0.1 --max-len-b 0 --lenpen 1.0 \ - --results-path $decode_dir $opts \ - 2>&1 | tee $dir/logs/decode_$dataset${decode_affix:+_${decode_affix}}.log + --results-path $decode_dir $opts + echo "log saved in ${decode_dir}/decode.log" echo "Scoring with kaldi..." local/score.sh data/$dataset $decode_dir if [ "$dataset" == "train_dev" ]; then diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index 1b2d94ccb..47ab4228a 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -306,19 +306,20 @@ if [ ${stage} -le 9 ]; then feat=$test_feat_dir/feats.scp fi text=data/$dataset/token_text + decode_dir=$dir/decode_$dataset${decode_affix:+_${decode_affix}} CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py \ --task speech_recognition_espresso --user-dir espresso --max-tokens 20000 --max-sentences 32 \ --num-shards 1 --shard-id 0 --test-feat-files $feat --test-text-files $text \ --dict $dict --non-lang-syms $nlsyms \ --max-source-positions 9999 --max-target-positions 999 \ --path $path --beam 50 --max-len-a 0.2 --max-len-b 0 --lenpen 1.0 \ - --results-path $dir/decode_$dataset${decode_affix:+_${decode_affix}} $opts \ - --print-alignment 2>&1 | tee $dir/logs/decode_$dataset${decode_affix:+_${decode_affix}}.log + --results-path $decode_dir $opts --print-alignment + echo "log saved in ${decode_dir}/decode.log" if $kaldi_scoring; then echo "verify WER by scoring with Kaldi..." - local/score.sh data/$dataset $dir/decode_$dataset${decode_affix:+_${decode_affix}} - cat $dir/decode_$dataset${decode_affix:+_${decode_affix}}/scoring_kaldi/wer + local/score.sh data/$dataset $decode_dir + cat ${decode_dir}/scoring_kaldi/wer fi done fi diff --git a/tests/espresso/test_speech_utils.py b/tests/espresso/test_speech_utils.py index bbe991e48..c3eb10ce9 100644 --- a/tests/espresso/test_speech_utils.py +++ b/tests/espresso/test_speech_utils.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import logging import unittest import string import numpy as np @@ -15,6 +16,9 @@ import espresso.tools.utils as utils +logger = logging.getLogger(__name__) + + class TestSpeechUtils(unittest.TestCase): @staticmethod @@ -70,8 +74,8 @@ def setUp(self): def test_speech_tokenizer(self): for i, sent in enumerate(self.text): - print('test sentence {}:'.format(i)) - print(sent) + logger.info('test sentence {}:'.format(i)) + logger.info(sent) tokens = utils.tokenize( sent, space=self.dictionary.space_word, non_lang_syms=self.non_lang_syms, From 293a068a4ed2a02cba2a92f90b9c48990201d42d Mon Sep 17 00:00:00 2001 From: freewym Date: Sun, 19 Jan 2020 04:42:00 -0500 Subject: [PATCH 062/119] move WER validation with greedy decoding code from criterion to task; valid loss is now based on teacher enforcing instead of greedy decoding; rename {,label_smoothed_}cross_entropy_with_wer.py to {,label_smoothed_}cross_entropy_v2.py --- espresso/criterions/cross_entropy_v2.py | 97 +++++++++++++ espresso/criterions/cross_entropy_with_wer.py | 133 ------------------ ....py => label_smoothed_cross_entropy_v2.py} | 127 +++++++---------- espresso/tasks/speech_recognition.py | 49 ++++++- examples/asr_librispeech/run.sh | 2 +- examples/asr_swbd/run.sh | 2 +- examples/asr_wsj/run.sh | 2 +- 7 files changed, 195 insertions(+), 217 deletions(-) create mode 100644 espresso/criterions/cross_entropy_v2.py delete mode 100644 espresso/criterions/cross_entropy_with_wer.py rename espresso/criterions/{label_smoothed_cross_entropy_with_wer.py => label_smoothed_cross_entropy_v2.py} (58%) diff --git a/espresso/criterions/cross_entropy_v2.py b/espresso/criterions/cross_entropy_v2.py new file mode 100644 index 000000000..5c97873b2 --- /dev/null +++ b/espresso/criterions/cross_entropy_v2.py @@ -0,0 +1,97 @@ +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import numpy as np + +import torch.nn.functional as F + +from fairseq import utils +from fairseq.criterions import register_criterion +from fairseq.criterions.cross_entropy import CrossEntropyCriterion +from fairseq.data import data_utils + + +logger = logging.getLogger(__name__) + + +@register_criterion('cross_entropy_v2') +class CrossEntropyV2Criterion(CrossEntropyCriterion): + + def __init__(self, args, task): + super().__init__(args, task) + + self.dictionary = task.target_dictionary + self.num_updates = -1 + self.epoch = 0 + + @staticmethod + def add_args(parser): + """Add criterion-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--print-training-sample-interval', type=int, + metavar='N', dest='print_interval', default=500, + help='print a training sample (reference + ' + 'prediction) every this number of updates') + # fmt: on + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample; periodically print out + randomly sampled predictions from the training set. 
+ + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample['net_input'], epoch=self.epoch) + loss, _, lprobs = self.compute_loss(model, net_output, sample, reduce=reduce) + sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] + logging_output = { + 'loss': utils.item(loss.data) if reduce else loss.data, + 'ntokens': sample['ntokens'], + 'nsentences': sample['target'].size(0), + 'sample_size': sample_size, + } + + if ( + model.training and self.num_updates // self.args.print_interval > + (self.num_updates - 1) // self.args.print_interval + ): # print a randomly sampled result every print_interval updates + target = model.get_targets(sample, net_output) + pred = lprobs.argmax(-1).cpu() # bsz x len + assert pred.size() == target.size() + with data_utils.numpy_seed(self.num_updates): + i = np.random.randint(0, len(sample['id'])) + ref_tokens = sample['target_raw_text'][i] + length = utils.strip_pad(target.data[i], self.padding_idx).size(0) + ref_one = self.dictionary.tokens_to_sentence( + ref_tokens, use_unk_sym=False, bpe_symbol=self.args.remove_bpe, + ) + pred_one = self.dictionary.tokens_to_sentence( + self.dictionary.string(pred.data[i][:length]), use_unk_sym=True, + bpe_symbol=self.args.remove_bpe, + ) + logger.info('sample REF: ' + ref_one) + logger.info('sample PRD: ' + pred_one) + + return loss, sample_size, logging_output + + def compute_loss(self, model, net_output, sample, reduce=True): + lprobs = model.get_normalized_probs(net_output, log_probs=True) + target = model.get_targets(sample, net_output) + loss = F.nll_loss( + lprobs.view(-1, lprobs.size(-1)), + target.view(-1), + ignore_index=self.padding_idx, + reduction='sum' if reduce else 'none', + ) + return loss, loss, lprobs + + def set_num_updates(self, num_updates): + self.num_updates = num_updates + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/espresso/criterions/cross_entropy_with_wer.py b/espresso/criterions/cross_entropy_with_wer.py deleted file mode 100644 index 3946040ec..000000000 --- a/espresso/criterions/cross_entropy_with_wer.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright (c) Yiming Wang -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
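Note on the condition `self.num_updates // self.args.print_interval > (self.num_updates - 1) // self.args.print_interval` used in forward() above: it is true exactly when the update counter crosses a multiple of the interval, i.e. once every print_interval updates. A small standalone sketch of that test (function name is illustrative):

def should_print_sample(num_updates, interval):
    # True exactly when num_updates crosses a multiple of `interval`.
    return num_updates // interval > (num_updates - 1) // interval

fired = [n for n in range(1, 1501) if should_print_sample(n, 500)]
assert fired == [500, 1000, 1500]
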
- -import logging -import numpy as np - -import torch.nn.functional as F - -from fairseq import metrics, utils -from fairseq.criterions import register_criterion -from fairseq.criterions.cross_entropy import CrossEntropyCriterion -from fairseq.data import data_utils - -from espresso.tools import wer -from espresso.tools.simple_greedy_decoder import SimpleGreedyDecoder - - -logger = logging.getLogger(__name__) - - -@register_criterion('cross_entropy_with_wer') -class CrossEntropyWithWERCriterion(CrossEntropyCriterion): - - def __init__(self, args, task): - super().__init__(args, task) - - dictionary = task.target_dictionary - self.scorer = wer.Scorer(dictionary, wer_output_filter=task.args.wer_output_filter) - self.decoder_for_validation = SimpleGreedyDecoder(dictionary, for_validation=True) - self.num_updates = -1 - self.epoch = 0 - - @staticmethod - def add_args(parser): - """Add criterion-specific arguments to the parser.""" - # fmt: off - parser.add_argument('--print-training-sample-interval', type=int, - metavar='N', dest='print_interval', default=500, - help='print a training sample (reference + ' - 'prediction) every this number of updates') - # fmt: on - - def forward(self, model, sample, reduce=True): - """Compute the loss for the given sample; periodically print out - randomly sampled predictions if model is in training mode, otherwise - aggregate word error stats for validation. - - Returns a tuple with three elements: - 1) the loss - 2) the sample size, which is used as the denominator for the gradient - 3) logging outputs to display while training - """ - dictionary = self.scorer.dictionary - if model.training: - net_output = model(**sample['net_input'], epoch=self.epoch) - lprobs = model.get_normalized_probs(net_output, log_probs=True) - target = model.get_targets(sample, net_output) - if ( - self.num_updates // self.args.print_interval > - (self.num_updates - 1) // self.args.print_interval - ): # print a randomly sampled result every print_interval updates - pred = lprobs.argmax(-1).cpu() # bsz x len - assert pred.size() == target.size() - with data_utils.numpy_seed(self.num_updates): - i = np.random.randint(0, len(sample['id'])) - ref_tokens = sample['target_raw_text'][i] - length = utils.strip_pad(target.data[i], self.padding_idx).size(0) - ref_one = dictionary.tokens_to_sentence( - ref_tokens, use_unk_sym=False, bpe_symbol=self.args.remove_bpe, - ) - pred_one = dictionary.tokens_to_sentence( - dictionary.string(pred.data[i][:length]), use_unk_sym=True, - bpe_symbol=self.args.remove_bpe, - ) - logger.info('sample REF: ' + ref_one) - logger.info('sample PRD: ' + pred_one) - else: - tokens, lprobs, _ = self.decoder_for_validation.decode([model], sample) - pred = tokens[:, 1:].data.cpu() # bsz x len - target = sample['target'] - # compute word error stats - assert pred.size(0) == target.size(0) - self.scorer.reset() - for i in range(target.size(0)): - utt_id = sample['utt_id'][i] - ref_tokens = sample['target_raw_text'][i] - pred_tokens = dictionary.string(pred.data[i]) - self.scorer.add_evaluation( - utt_id, ref_tokens, pred_tokens, bpe_symbol=self.args.remove_bpe, - ) - - lprobs = lprobs.view(-1, lprobs.size(-1)) - loss = F.nll_loss( - lprobs, - target.view(-1), - ignore_index=self.padding_idx, - reduction='sum' if reduce else 'none', - ) - sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] - logging_output = { - 'loss': utils.item(loss.data) if reduce else loss.data, - 'ntokens': sample['ntokens'], - 'nsentences': 
sample['target'].size(0), - 'sample_size': sample_size, - } - if not model.training: # do not compute word error in training mode - logging_output['word_error'] = self.scorer.tot_word_error() - logging_output['word_count'] = self.scorer.tot_word_count() - logging_output['char_error'] = self.scorer.tot_char_error() - logging_output['char_count'] = self.scorer.tot_char_count() - return loss, sample_size, logging_output - - @staticmethod - def reduce_metrics(logging_outputs) -> None: - """Aggregate logging outputs from data parallel training.""" - CrossEntropyCriterion.reduce_metrics(logging_outputs) - - word_error = sum(log.get('word_error', 0) for log in logging_outputs) - word_count = sum(log.get('word_count', 0) for log in logging_outputs) - char_error = sum(log.get('char_error', 0) for log in logging_outputs) - char_count = sum(log.get('char_count', 0) for log in logging_outputs) - if word_count > 0: # model.training == False - metrics.log_scalar('wer', float(word_error) / word_count * 100, word_count, round=4) - if char_count > 0: # model.training == False - metrics.log_scalar('wer', float(word_error) / word_count * 100, word_count, round=4) - - def set_num_updates(self, num_updates): - self.num_updates = num_updates - - def set_epoch(self, epoch): - self.epoch = epoch diff --git a/espresso/criterions/label_smoothed_cross_entropy_with_wer.py b/espresso/criterions/label_smoothed_cross_entropy_v2.py similarity index 58% rename from espresso/criterions/label_smoothed_cross_entropy_with_wer.py rename to espresso/criterions/label_smoothed_cross_entropy_v2.py index 3f3ea1190..9b4b797b3 100644 --- a/espresso/criterions/label_smoothed_cross_entropy_with_wer.py +++ b/espresso/criterions/label_smoothed_cross_entropy_v2.py @@ -8,14 +8,11 @@ import torch -from fairseq import metrics, utils +from fairseq import utils from fairseq.criterions import register_criterion from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion from fairseq.data import data_utils -from espresso.tools import wer -from espresso.tools.simple_greedy_decoder import SimpleGreedyDecoder - logger = logging.getLogger(__name__) @@ -80,22 +77,20 @@ def label_smoothed_nll_loss( return loss, nll_loss -@register_criterion('label_smoothed_cross_entropy_with_wer') -class LabelSmoothedCrossEntropyWithWERCriterion(LabelSmoothedCrossEntropyCriterion): +@register_criterion('label_smoothed_cross_entropy_v2') +class LabelSmoothedCrossEntropyV2Criterion(LabelSmoothedCrossEntropyCriterion): def __init__(self, args, task): super().__init__(args, task) - dictionary = task.target_dictionary - self.scorer = wer.Scorer(dictionary, wer_output_filter=task.args.wer_output_filter) - self.decoder_for_validation = SimpleGreedyDecoder(dictionary, for_validation=True) + self.dictionary = task.target_dictionary self.num_updates = -1 self.epoch = 0 self.unigram_tensor = None if args.smoothing_type == 'unigram': - self.unigram_tensor = torch.cuda.FloatTensor(dictionary.count).unsqueeze(-1) \ + self.unigram_tensor = torch.cuda.FloatTensor(self.dictionary.count).unsqueeze(-1) \ if torch.cuda.is_available() and not args.cpu \ - else torch.FloatTensor(dictionary.count).unsqueeze(-1) + else torch.FloatTensor(self.dictionary.count).unsqueeze(-1) self.unigram_tensor += args.unigram_pseudo_count # for further backoff self.unigram_tensor.div_(self.unigram_tensor.sum()) @@ -118,63 +113,16 @@ def add_args(parser): def forward(self, model, sample, reduce=True): """Compute the loss for the given sample; periodically print out - randomly 
sampled predictions if model is in training mode, otherwise - aggregate word error stats for validation. + randomly sampled predictions from the training set. Returns a tuple with three elements: 1) the loss 2) the sample size, which is used as the denominator for the gradient 3) logging outputs to display while training """ - dictionary = self.scorer.dictionary - if model.training: - net_output = model(**sample['net_input'], epoch=self.epoch) - lprobs = model.get_normalized_probs(net_output, log_probs=True) - target = model.get_targets(sample, net_output) - if ( - self.num_updates // self.args.print_interval > - (self.num_updates - 1) // self.args.print_interval - ): # print a randomly sampled result every print_interval updates - pred = lprobs.argmax(-1).cpu() # bsz x len - assert pred.size() == target.size() - with data_utils.numpy_seed(self.num_updates): - i = np.random.randint(0, len(sample['id'])) - ref_tokens = sample['target_raw_text'][i] - length = utils.strip_pad(target.data[i], self.padding_idx).size(0) - ref_one = dictionary.tokens_to_sentence( - ref_tokens, use_unk_sym=False, bpe_symbol=self.args.remove_bpe, - ) - pred_one = dictionary.tokens_to_sentence( - dictionary.string(pred.data[i][:length]), use_unk_sym=True, - bpe_symbol=self.args.remove_bpe, - ) - logger.info('sample REF: ' + ref_one) - logger.info('sample PRD: ' + pred_one) - else: - tokens, lprobs, _ = self.decoder_for_validation.decode([model], sample) - pred = tokens[:, 1:].data.cpu() # bsz x len - target = sample['target'] - # compute word error stats - assert pred.size(0) == target.size(0) - self.scorer.reset() - for i in range(target.size(0)): - utt_id = sample['utt_id'][i] - ref_tokens = sample['target_raw_text'][i] - pred_tokens = dictionary.string(pred.data[i]) - self.scorer.add_evaluation( - utt_id, ref_tokens, pred_tokens, bpe_symbol=self.args.remove_bpe, - ) - - prob_mask = temporal_label_smoothing_prob_mask( - lprobs, target, padding_index=self.padding_idx, - ) if self.args.smoothing_type == 'temporal' else None - - lprobs = lprobs.view(-1, lprobs.size(-1)) - target = target.view(-1, 1) - loss, nll_loss = label_smoothed_nll_loss( - lprobs, target, self.eps, ignore_index=self.padding_idx, reduce=reduce, - smoothing_type=self.args.smoothing_type, prob_mask=prob_mask, - unigram_tensor=self.unigram_tensor, + net_output = model(**sample['net_input'], epoch=self.epoch) + loss, nll_loss, lprobs = self.compute_loss( + model, net_output, sample, reduce=reduce, smoothing_type=self.args.smoothing_type ) sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] logging_output = { @@ -184,26 +132,45 @@ def forward(self, model, sample, reduce=True): 'nsentences': sample['target'].size(0), 'sample_size': sample_size, } - if not model.training: # do not compute word error in training mode - logging_output['word_error'] = self.scorer.tot_word_error() - logging_output['word_count'] = self.scorer.tot_word_count() - logging_output['char_error'] = self.scorer.tot_char_error() - logging_output['char_count'] = self.scorer.tot_char_count() + + if ( + model.training and self.num_updates // self.args.print_interval > + (self.num_updates - 1) // self.args.print_interval + ): # print a randomly sampled result every print_interval updates + target = model.get_targets(sample, net_output) + pred = lprobs.argmax(-1).cpu() # bsz x len + assert pred.size() == target.size() + with data_utils.numpy_seed(self.num_updates): + i = np.random.randint(0, len(sample['id'])) + ref_tokens = sample['target_raw_text'][i] 
+ length = utils.strip_pad(target.data[i], self.padding_idx).size(0) + ref_one = self.dictionary.tokens_to_sentence( + ref_tokens, use_unk_sym=False, bpe_symbol=self.args.remove_bpe, + ) + pred_one = self.dictionary.tokens_to_sentence( + self.dictionary.string(pred.data[i][:length]), use_unk_sym=True, + bpe_symbol=self.args.remove_bpe, + ) + logger.info('sample REF: ' + ref_one) + logger.info('sample PRD: ' + pred_one) + return loss, sample_size, logging_output - @staticmethod - def reduce_metrics(logging_outputs) -> None: - """Aggregate logging outputs from data parallel training.""" - LabelSmoothedCrossEntropyCriterion.reduce_metrics(logging_outputs) - - word_error = sum(log.get('word_error', 0) for log in logging_outputs) - word_count = sum(log.get('word_count', 0) for log in logging_outputs) - char_error = sum(log.get('char_error', 0) for log in logging_outputs) - char_count = sum(log.get('char_count', 0) for log in logging_outputs) - if word_count > 0: # model.training == False - metrics.log_scalar('wer', float(word_error) / word_count * 100, word_count, round=4) - if char_count > 0: # model.training == False - metrics.log_scalar('cer', float(char_error) / char_count * 100, char_count, round=4) + def compute_loss( + self, model, net_output, sample, reduce=True, smoothing_type='uniform' + ): + lprobs = model.get_normalized_probs(net_output, log_probs=True) + target = model.get_targets(sample, net_output) + prob_mask = temporal_label_smoothing_prob_mask( + lprobs, target, padding_index=self.padding_idx, + ) if smoothing_type == 'temporal' else None + loss, nll_loss = label_smoothed_nll_loss( + lprobs.view(-1, lprobs.size(-1)), target.view(-1, 1), self.eps, + ignore_index=self.padding_idx, reduce=reduce, + smoothing_type=smoothing_type, prob_mask=prob_mask, + unigram_tensor=self.unigram_tensor, + ) + return loss, nll_loss, lprobs def set_num_updates(self, num_updates): self.num_updates = num_updates diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index e5cc80cea..227569703 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -8,7 +8,7 @@ import torch -from fairseq import options, search +from fairseq import metrics, options, search from fairseq.data import ConcatDataset from fairseq.tasks import FairseqTask, register_task @@ -294,12 +294,37 @@ def build_generator(self, args): def build_dataset_for_inference(self, src_tokens, src_lengths): return SpeechDataset(src_tokens, src_lengths) + def build_model(self, args): + # build the greedy decoder for validation with WER + from espresso.tools.simple_greedy_decoder import SimpleGreedyDecoder + self.decoder_for_validation = SimpleGreedyDecoder(self.target_dictionary, for_validation=True) + return super().build_model(args) + + def valid_step(self, sample, model, criterion): + loss, sample_size, logging_output = super().valid_step(sample, model, criterion) + ( + logging_output['word_error'], logging_output['word_count'], + logging_output['char_error'], logging_output['char_count'], + ) = self._inference_with_wer(self.decoder_for_validation, sample, model) + return loss, sample_size, logging_output + def inference_step(self, generator, models, sample, prefix_tokens=None, lm_weight=0.0): with torch.no_grad(): return generator.generate( models, sample, prefix_tokens=prefix_tokens, lm_weight=lm_weight, ) + def reduce_metrics(self, logging_outputs, criterion): + super().reduce_metrics(logging_outputs, criterion) + word_error = sum(log.get('word_error', 0) for log in 
logging_outputs) + word_count = sum(log.get('word_count', 0) for log in logging_outputs) + char_error = sum(log.get('char_error', 0) for log in logging_outputs) + char_count = sum(log.get('char_count', 0) for log in logging_outputs) + if word_count > 0: + metrics.log_scalar('wer', float(word_error) / word_count * 100, word_count, round=4) + if char_count > 0: + metrics.log_scalar('cer', float(char_error) / char_count * 100, char_count, round=4) + def max_positions(self): """Return the max sentence length allowed by the task.""" return (self.args.max_source_positions, self.args.max_target_positions) @@ -313,3 +338,25 @@ def target_dictionary(self): def word_dictionary(self): """Return the target :class:`~fairseq.data.AsrDictionary`.""" return self.word_dict + + def _inference_with_wer(self, decoder, sample, model): + from espresso.tools import wer + + scorer = wer.Scorer(self.target_dictionary, wer_output_filter=self.args.wer_output_filter) + tokens, lprobs, _ = decoder.decode([model], sample) + pred = tokens[:, 1:].data.cpu() # bsz x len + target = sample['target'] + assert pred.size(0) == target.size(0) + # compute word error stats + scorer.reset() + for i in range(target.size(0)): + utt_id = sample['utt_id'][i] + ref_tokens = sample['target_raw_text'][i] + pred_tokens = self.target_dictionary.string(pred.data[i]) + scorer.add_evaluation( + utt_id, ref_tokens, pred_tokens, bpe_symbol=self.args.remove_bpe, + ) + return ( + scorer.tot_word_error(), scorer.tot_word_count(), + scorer.tot_char_error(), scorer.tot_char_count(), + ) diff --git a/examples/asr_librispeech/run.sh b/examples/asr_librispeech/run.sh index ce02fd404..59a3fc131 100755 --- a/examples/asr_librispeech/run.sh +++ b/examples/asr_librispeech/run.sh @@ -214,7 +214,7 @@ if [ ${stage} -le 7 ]; then --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --start-reduce-lr-epoch 10 \ --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates 3000 \ --keep-interval-updates 3 --keep-last-epochs 5 --validate-interval 1 --best-checkpoint-metric wer \ - --arch speech_conv_lstm_librispeech --criterion label_smoothed_cross_entropy_with_wer \ + --arch speech_conv_lstm_librispeech --criterion label_smoothed_cross_entropy_v2 \ --label-smoothing 0.1 --smoothing-type uniform \ --scheduled-sampling-probs 1.0 --start-scheduled-sampling-epoch 1 \ --train-feat-files $train_feat --train-text-files $train_token_text \ diff --git a/examples/asr_swbd/run.sh b/examples/asr_swbd/run.sh index 10b7e8371..09a4650fb 100755 --- a/examples/asr_swbd/run.sh +++ b/examples/asr_swbd/run.sh @@ -255,7 +255,7 @@ if [ $stage -le 6 ]; then --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --start-reduce-lr-epoch 14 \ --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates 1500 \ --keep-interval-updates 3 --keep-last-epochs 5 --validate-interval 1 --best-checkpoint-metric wer \ - --arch speech_conv_lstm_swbd --criterion label_smoothed_cross_entropy_with_wer \ + --arch speech_conv_lstm_swbd --criterion label_smoothed_cross_entropy_v2 \ --label-smoothing 0.1 --smoothing-type uniform \ --scheduled-sampling-probs 0.9,0.8,0.7,0.6 --start-scheduled-sampling-epoch 6 \ --train-feat-files $train_feat --train-text-files $train_token_text \ diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index 47ab4228a..a50f81a8b 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -273,7 +273,7 @@ if [ ${stage} -le 8 ]; then --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --start-reduce-lr-epoch 11 \ --save-dir $dir 
--restore-file checkpoint_last.pt --save-interval-updates 400 \ --keep-interval-updates 5 --keep-last-epochs 5 --validate-interval 1 --best-checkpoint-metric wer \ - --arch speech_conv_lstm_wsj --criterion label_smoothed_cross_entropy_with_wer \ + --arch speech_conv_lstm_wsj --criterion label_smoothed_cross_entropy_v2 \ --label-smoothing 0.05 --smoothing-type temporal \ --scheduled-sampling-probs 0.5 --start-scheduled-sampling-epoch 6 \ --train-feat-files $train_feat --train-text-files $train_token_text \ From 80f564d68266417d43a7c0aa40320c9a2bf4fac7 Mon Sep 17 00:00:00 2001 From: freewym Date: Mon, 20 Jan 2020 20:40:57 -0500 Subject: [PATCH 063/119] code adaptation/changes according to the commits on Jan 20, 2020; cosmetic changes for lookahead LM --- espresso/models/external_language_model.py | 23 ++++---- .../tensorized_lookahead_language_model.py | 29 +++++---- espresso/speech_recognize.py | 2 +- espresso/speech_train.py | 59 +++++++++---------- espresso/tasks/language_modeling_for_asr.py | 2 +- examples/asr_swbd/run.sh | 3 +- 6 files changed, 56 insertions(+), 62 deletions(-) diff --git a/espresso/models/external_language_model.py b/espresso/models/external_language_model.py index d3661b5f3..505a2ca48 100644 --- a/espresso/models/external_language_model.py +++ b/espresso/models/external_language_model.py @@ -101,7 +101,8 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): batch_space_mask = prev_output_tokens.squeeze(-1).eq(self.subword_space_idx) cached_state = utils.get_incremental_state( - self.lm_decoder, incremental_state, 'cached_state') + self.lm_decoder, incremental_state, 'cached_state', + ) if cached_state is None: # it is the first time step assert (prev_output_tokens == self.subword_eos_idx).all(), \ @@ -109,12 +110,12 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): w = prev_output_tokens.new_full([bsz, 1], self.word_eos_idx) lm_probs = self.lm_decoder.get_normalized_probs( self.lm_decoder(w, incremental_state=incremental_state), - log_probs=False, sample=None) # B x 1 x V + log_probs=False, sample=None, + ) # B x 1 x V cumsum_probs = torch.cumsum(lm_probs, dim=-1) # B x 1 x V nodes = [self.lexroot] * bsz else: - cumsum_probs = utils.get_incremental_state( - self, incremental_state, 'cumsum_probs') + cumsum_probs = utils.get_incremental_state(self, incremental_state, 'cumsum_probs') nodes = utils.get_incremental_state(self, incremental_state, 'nodes') assert len(nodes) == bsz w = prev_output_tokens.new([ @@ -144,8 +145,7 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): else: # no path in the tree nodes[i] = None - utils.set_incremental_state( - self, incremental_state, 'cumsum_probs', cumsum_probs) + utils.set_incremental_state(self, incremental_state, 'cumsum_probs', cumsum_probs) utils.set_incremental_state(self, incremental_state, 'nodes', nodes) # initialize out_probs (B x 1 x V) @@ -164,16 +164,13 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): out_probs[~batch_space_mask, :, self.subword_eos_idx] = self.zero # set transition probability to 1 for those whose node is out of the # tree, i.e. node is None (case 4 in Eqn. 15) - batch_node_none_mask = [] - for node in nodes: - batch_node_none_mask.append(node is None) - batch_node_none_mask = batch_space_mask.new(batch_node_none_mask) + batch_node_none_mask = batch_space_mask.new( + [node is None for node in nodes] + ) out_probs[batch_node_none_mask] = 1. 
else: # set out_probs to 0 - out_probs = cumsum_probs.new_full( - [bsz, 1, self.subword_vocab_size], self.zero, - ) + out_probs = cumsum_probs.new_full([bsz, 1, self.subword_vocab_size], self.zero) # compute parent probabilities for those whose node is not None sum_probs = cumsum_probs.new_full([bsz, 1], 1.) # default for root node diff --git a/espresso/models/tensorized_lookahead_language_model.py b/espresso/models/tensorized_lookahead_language_model.py index a8d79af82..872e38fe8 100644 --- a/espresso/models/tensorized_lookahead_language_model.py +++ b/espresso/models/tensorized_lookahead_language_model.py @@ -35,8 +35,7 @@ def __init__(self, word_lm: FairseqLanguageModel, subword_dict: AsrDictionary, oov_penalty: float = 1e-4, - open_vocab: bool = True - ): + open_vocab: bool = True): decoder = _TensorizedLookaheadLanguageModelDecoder(word_lm, subword_dict, oov_penalty, open_vocab) super().__init__(decoder) @@ -105,14 +104,15 @@ def forward(self, w: torch.Tensor = prev_output_tokens.new_full([bsz, 1], self.word_eos_idx) # Z[Batch, Len=1] lm_probs: torch.Tensor = self.lm_decoder.get_normalized_probs( self.lm_decoder(w, incremental_state=incremental_state), - log_probs=False, sample=None) # R[Batch, 1, Vocab] + log_probs=False, sample=None, + ) # R[Batch, 1, Vocab] cumsum_probs: torch.Tensor = lm_probs.cumsum(dim=-1) # R[Batch, 1, Vocab] nodes: torch.Tensor = prev_output_tokens.new_full([bsz], self.tree.root_id) # Z_NodeId[Batch] - all_children = self.tree.children[nodes, :] # Z[Batch, PossibleChildren] else: # Not the first step cumsum_probs: torch.Tensor = utils.get_incremental_state( - self, incremental_state, 'cumsum_probs') # R[Batch, 1, Vocab] + self, incremental_state, 'cumsum_probs', + ) # R[Batch, 1, Vocab] nodes: torch.Tensor = utils.get_incremental_state(self, incremental_state, 'nodes') # Z_NodeId[Batch] assert nodes.size(0) == bsz w: torch.Tensor = self.tree.word_idx[nodes].unsqueeze(1) # Z[Batch, Len=1] @@ -123,9 +123,11 @@ def forward(self, # only for those whose prev_output_token is lm_probs: torch.Tensor = self.lm_decoder.get_normalized_probs( self.lm_decoder(w, incremental_state=incremental_state), - log_probs=False, sample=None) # R[Batch, 1, Vocab] + log_probs=False, sample=None, + ) # R[Batch, 1, Vocab] self.lm_decoder.masked_copy_incremental_state( - incremental_state, old_cached_state, batch_space_mask) # restore those not masked + incremental_state, old_cached_state, batch_space_mask, + ) # restore those not masked cumsum_probs[batch_space_mask] = lm_probs.cumsum(dim=-1)[batch_space_mask] prev_all_children = self.tree.children[nodes, :] # Z[Batch, PossibleChildren] @@ -135,7 +137,8 @@ def forward(self, nodes: torch.Tensor = (prev_all_children * mask.long()).sum(dim=1) # Z[Batch] # inter-word transition: go back to root nodes[batch_space_mask] = self.tree.root_id # Z[Batch] - all_children = self.tree.children[nodes, :] # Z[Batch, PossibleChildren] + + all_children = self.tree.children[nodes, :] # Z[Batch, PossibleChildren] utils.set_incremental_state(self, incremental_state, 'cumsum_probs', cumsum_probs) utils.set_incremental_state(self, incremental_state, 'nodes', nodes) @@ -159,8 +162,7 @@ def forward(self, # set transition probability to 1 for those whose node is out of the # tree, i.e. node is None (case 4 in Eqn. 15) - batch_node_none_mask = nodes.eq(self.tree.none_id) # B[Batch] - out_probs[batch_node_none_mask] = 1. + out_probs[nodes.eq(self.tree.none_id)] = 1. 
else: # set out_probs to 0 out_probs = cumsum_probs.new_full([bsz, 1, self.subword_vocab_size], self.zero) @@ -168,9 +170,8 @@ def forward(self, # compute parent probabilities for those whose node is not None left_ranges = self.tree.word_set_idx[nodes, 0] # Z[Batch] right_ranges = self.tree.word_set_idx[nodes, 1] # Z[Batch] - batch_node_not_root_mask = nodes.ne(self.tree.none_id) & nodes.ne(self.tree.root_id) # B[Batch] sum_probs = torch.where( - batch_node_not_root_mask, + nodes.ne(self.tree.none_id) & nodes.ne(self.tree.root_id), (cumsum_probs.squeeze(1).gather(-1, right_ranges.unsqueeze(-1)) - cumsum_probs.squeeze(1).gather(-1, left_ranges.unsqueeze(-1))).squeeze(-1), cumsum_probs.new([1.0]) @@ -188,7 +189,7 @@ def forward(self, out_probs.scatter_( -1, next_possible_tokens.unsqueeze(1), - cumsum_probs_of_all_children + cumsum_probs_of_all_children, ) # assume self.subword_pad_idx is the padding index in self.tree.prev_subword_idx out_probs[:, :, self.subword_pad_idx] = self.zero @@ -218,8 +219,6 @@ def forward(self, out_logprobs[batch_space_mask, :, self.subword_eos_idx] = \ torch.log(lm_probs)[batch_space_mask, :, self.word_eos_idx] - utils.set_incremental_state(self, incremental_state, 'out_logprobs', out_logprobs) - # note that here we return log-probs rather than logits, and the second # element is None, which is usually a tensor of attention weights in # attention-based models diff --git a/espresso/speech_recognize.py b/espresso/speech_recognize.py index 48fdac0c7..d025e7875 100755 --- a/espresso/speech_recognize.py +++ b/espresso/speech_recognize.py @@ -69,7 +69,7 @@ def _main(args, output_file): # Load ensemble logger.info('loading model(s) from {}'.format(args.path)) models, _model_args = checkpoint_utils.load_model_ensemble( - args.path.split(':'), + args.path.split(os.pathsep), arg_overrides=eval(args.model_overrides), task=task, ) diff --git a/espresso/speech_train.py b/espresso/speech_train.py index 332ad8b2e..b8f0f55a0 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -142,6 +142,7 @@ def is_better(a, b): return should_stop_early.num_runs > args.patience +@metrics.aggregate('train') def train(args, trainer, task, epoch_itr): """Train the model for one epoch.""" # Initialize data iterator @@ -166,34 +167,33 @@ def train(args, trainer, task, epoch_itr): max_update = args.max_update or math.inf if hasattr(trainer.criterion, 'set_epoch'): trainer.criterion.set_epoch(epoch_itr.epoch) - with metrics.aggregate() as agg: - for samples in progress: - if hasattr(trainer.criterion, 'set_num_updates'): - trainer.criterion.set_num_updates(trainer.get_num_updates()) - - log_output = trainer.train_step(samples) - num_updates = trainer.get_num_updates() - if log_output is None: - continue - - # log mid-epoch stats - stats = get_training_stats(agg.get_smoothed_values()) - progress.log(stats, tag='train', step=num_updates) - - if ( - not args.disable_validation - and args.save_interval_updates > 0 - and num_updates % args.save_interval_updates == 0 - and num_updates > 0 - ): - valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets) - checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) - - if num_updates >= max_update: - break + for samples in progress: + if hasattr(trainer.criterion, 'set_num_updates'): + trainer.criterion.set_num_updates(trainer.get_num_updates()) + + log_output = trainer.train_step(samples) + num_updates = trainer.get_num_updates() + if log_output is None: + continue + + # log mid-epoch stats + stats = 
get_training_stats(metrics.get_smoothed_values('train')) + progress.log(stats, tag='train', step=num_updates) + + if ( + not args.disable_validation + and args.save_interval_updates > 0 + and num_updates % args.save_interval_updates == 0 + and num_updates > 0 + ): + valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets) + checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) + + if num_updates >= max_update: + break # log end-of-epoch stats - stats = get_training_stats(agg.get_smoothed_values()) + stats = get_training_stats(metrics.get_smoothed_values('train')) progress.print(stats, tag='train', step=num_updates) # reset epoch-level meters @@ -238,10 +238,9 @@ def validate(args, trainer, task, epoch_itr, subsets): no_progress_bar='simple' ) - # reset validation meters - metrics.reset_meters('valid') - - with metrics.aggregate() as agg: + # create a new root metrics aggregator so validation metrics + # don't pollute other aggregators (e.g., train meters) + with metrics.aggregate(new_root=True) as agg: for sample in progress: trainer.valid_step(sample) diff --git a/espresso/tasks/language_modeling_for_asr.py b/espresso/tasks/language_modeling_for_asr.py index 433f224b0..395f18158 100644 --- a/espresso/tasks/language_modeling_for_asr.py +++ b/espresso/tasks/language_modeling_for_asr.py @@ -105,7 +105,7 @@ def setup_task(cls, args, **kwargs): dictionary = None output_dictionary = None if args.data: - paths = args.data.split(":") + paths = args.data.split(os.pathsep) assert len(paths) > 0 dict_path = os.path.join(paths[0], "dict.txt") if args.dict is None \ else args.dict diff --git a/examples/asr_swbd/run.sh b/examples/asr_swbd/run.sh index 09a4650fb..3a645e222 100755 --- a/examples/asr_swbd/run.sh +++ b/examples/asr_swbd/run.sh @@ -47,7 +47,6 @@ do_delta=false . ./utils/parse_options.sh lmdir=exp/lm_lstm${lm_affix:+_${lm_affix}} -wordlmdir=exp/wordlm_lstm${wordlm_affix:+_${wordlm_affix}} dir=exp/lstm${affix:+_$affix} if [ $stage -le 0 ]; then @@ -227,7 +226,7 @@ if [ $stage -le 5 ]; then test_set_array=($test_set) for i in $(seq 0 $num); do log_file=$lmdir/logs/evaluation_${test_set_array[$i]}.log - python3 ../../eval_lm.py $lmdatadir --user-dir espresso \ + python3 ../../eval_lm.py $lmdatadir --user-dir espresso --cpu \ --task language_modeling_for_asr --dict $lmdict --gen-subset ${gen_set_array[$i]} \ --max-tokens 40960 --max-sentences 1536 --sample-break-mode eos \ --path $lmdir/$lm_checkpoint 2>&1 | tee $log_file From d488b92022f310875a13713e4abf3f187a23535c Mon Sep 17 00:00:00 2001 From: freewym Date: Fri, 24 Jan 2020 21:01:46 -0500 Subject: [PATCH 064/119] isolate LSTMLanguageModel from speech_lstm.py and rename it to LSTMLanguageModelEspresso to avoid naming conflicts with the fairseq\'s LSTMLanguageModel introduced recently --- espresso/models/lstm_lm.py | 183 +++++++++++++++ espresso/models/speech_lstm.py | 216 ++++-------------- espresso/speech_train.py | 8 +- .../scheduled_sampling_rate_scheduler.py | 8 +- espresso/tools/wer.py | 4 +- 5 files changed, 234 insertions(+), 185 deletions(-) create mode 100644 espresso/models/lstm_lm.py diff --git a/espresso/models/lstm_lm.py b/espresso/models/lstm_lm.py new file mode 100644 index 000000000..151046d55 --- /dev/null +++ b/espresso/models/lstm_lm.py @@ -0,0 +1,183 @@ +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from fairseq import options, utils +from fairseq.models import ( + FairseqLanguageModel, + register_model, + register_model_architecture, +) +from fairseq.models.lstm import Embedding + +from espresso.models.speech_lstm import SpeechLSTMDecoder +from espresso.tasks.speech_recognition import SpeechRecognitionEspressoTask + + +DEFAULT_MAX_TARGET_POSITIONS = 1e5 + + +@register_model('lstm_lm_espresso') +class LSTMLanguageModelEspresso(FairseqLanguageModel): + def __init__(self, decoder, args): + super().__init__(decoder) + self.is_wordlm = args.is_wordlm + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--dropout', type=float, metavar='D', + help='dropout probability') + parser.add_argument('--decoder-embed-dim', type=int, metavar='N', + help='decoder embedding dimension') + parser.add_argument('--decoder-embed-path', type=str, metavar='STR', + help='path to pre-trained decoder embedding') + parser.add_argument('--decoder-freeze-embed', action='store_true', + help='freeze decoder embeddings') + parser.add_argument('--decoder-hidden-size', type=int, metavar='N', + help='decoder hidden size') + parser.add_argument('--decoder-layers', type=int, metavar='N', + help='number of decoder layers') + parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N', + help='decoder output embedding dimension') + parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', + help='comma separated list of adaptive softmax cutoff points. ' + 'Must be used with adaptive_loss criterion') + parser.add_argument('--share-embed', + type=lambda x: options.eval_bool(x), + help='share input and output embeddings') + parser.add_argument('--is-wordlm', action='store_true', + help='whether it is word LM or subword LM. 
Only ' + 'relevant for ASR decoding with LM, and it determines ' + 'how the underlying decoder instance gets the dictionary' + 'from the task instance when calling cls.build_model()') + + # Granular dropout settings (if not specified these default to --dropout) + parser.add_argument('--decoder-dropout-in', type=float, metavar='D', + help='dropout probability for decoder input embedding') + parser.add_argument('--decoder-dropout-out', type=float, metavar='D', + help='dropout probability for decoder output') + # fmt: on + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + # make sure all arguments are present in older models + base_lm_architecture(args) + + if getattr(args, 'max_target_positions', None) is not None: + max_target_positions = args.max_target_positions + else: + max_target_positions = getattr(args, 'tokens_per_sample', DEFAULT_MAX_TARGET_POSITIONS) + + def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx) + embed_dict = utils.parse_embedding(embed_path) + utils.print_embed_overlap(embed_dict, dictionary) + return utils.load_embedding(embed_dict, dictionary, embed_tokens) + + if args.is_wordlm and hasattr(task, 'word_dictionary'): + dictionary = task.word_dictionary + elif isinstance(task, SpeechRecognitionEspressoTask): + dictionary = task.target_dictionary + else: + dictionary = task.source_dictionary + + # separate decoder input embeddings + pretrained_decoder_embed = None + if args.decoder_embed_path: + pretrained_decoder_embed = load_pretrained_embedding_from_file( + args.decoder_embed_path, + dictionary, + args.decoder_embed_dim + ) + # one last double check of parameter combinations + if args.share_embed and ( + args.decoder_embed_dim != args.decoder_out_embed_dim): + raise ValueError( + '--share-embed requires ' + '--decoder-embed-dim to match --decoder-out-embed-dim' + ) + + if args.decoder_freeze_embed: + pretrained_decoder_embed.weight.requires_grad = False + + decoder = SpeechLSTMDecoder( + dictionary=dictionary, + embed_dim=args.decoder_embed_dim, + hidden_size=args.decoder_hidden_size, + out_embed_dim=args.decoder_out_embed_dim, + num_layers=args.decoder_layers, + dropout_in=args.decoder_dropout_in, + dropout_out=args.decoder_dropout_out, + attn_type=None, + encoder_output_units=0, + pretrained_embed=pretrained_decoder_embed, + share_input_output_embed=args.share_embed, + adaptive_softmax_cutoff=( + options.eval_str_list(args.adaptive_softmax_cutoff, type=int) + if args.criterion == 'adaptive_loss' else None + ), + max_target_positions=max_target_positions, + ) + return cls(decoder, args) + + +@register_model_architecture('lstm_lm_espresso', 'lstm_lm_espresso') +def base_lm_architecture(args): + args.dropout = getattr(args, 'dropout', 0.1) + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 48) + args.decoder_embed_path = getattr(args, 'decoder_embed_path', None) + args.decoder_freeze_embed = getattr(args, 'decoder_freeze_embed', False) + args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 650) + args.decoder_layers = getattr(args, 'decoder_layers', 2) + args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 650) + args.decoder_rnn_residual = getattr(args, 'decoder_rnn_residual', False) + args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None) + args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', 
args.dropout) + args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout) + args.share_embed = getattr(args, 'share_embed', False) + args.is_wordlm = getattr(args, 'is_wordlm', False) + + +@register_model_architecture('lstm_lm_espresso', 'lstm_lm_wsj') +def lstm_lm_wsj(args): + base_lm_architecture(args) + + +@register_model_architecture('lstm_lm_espresso', 'lstm_lm_librispeech') +def lstm_lm_librispeech(args): + args.dropout = getattr(args, 'dropout', 0.0) + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 800) + args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 800) + args.decoder_layers = getattr(args, 'decoder_layers', 4) + args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 800) + args.share_embed = getattr(args, 'share_embed', True) + base_lm_architecture(args) + + +@register_model_architecture('lstm_lm_espresso', 'lstm_lm_swbd') +def lstm_lm_swbd(args): + args.dropout = getattr(args, 'dropout', 0.3) + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1800) + args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 1800) + args.decoder_layers = getattr(args, 'decoder_layers', 3) + args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 1800) + args.share_embed = getattr(args, 'share_embed', True) + base_lm_architecture(args) + + +@register_model_architecture('lstm_lm_espresso', 'lstm_wordlm_wsj') +def lstm_wordlm_wsj(args): + args.dropout = getattr(args, 'dropout', 0.35) + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1200) + args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 1200) + args.decoder_layers = getattr(args, 'decoder_layers', 3) + args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 1200) + args.share_embed = getattr(args, 'share_embed', True) + args.is_wordlm = True + base_lm_architecture(args) diff --git a/espresso/models/speech_lstm.py b/espresso/models/speech_lstm.py index 0fbb8a69c..c2c828fcc 100644 --- a/espresso/models/speech_lstm.py +++ b/espresso/models/speech_lstm.py @@ -14,7 +14,6 @@ FairseqDecoder, FairseqEncoder, FairseqIncrementalDecoder, - FairseqLanguageModel, FairseqEncoderDecoderModel, register_model, register_model_architecture, @@ -28,11 +27,14 @@ from fairseq.modules import AdaptiveSoftmax from espresso.modules import speech_attention -from espresso.tasks.speech_recognition import SpeechRecognitionEspressoTask from espresso.tools.scheduled_sampling_rate_scheduler import ScheduledSamplingRateScheduler import espresso.tools.utils as speech_utils +DEFAULT_MAX_SOURCE_POSITIONS = 1e5 +DEFAULT_MAX_TARGET_POSITIONS = 1e5 + + logger = logging.getLogger(__name__) @@ -114,7 +116,7 @@ def add_args(parser): # Scheduled sampling options parser.add_argument('--scheduled-sampling-probs', type=lambda p: options.eval_str_list(p), - metavar='P_1,P_2,...,P_N', default=1.0, + metavar='P_1,P_2,...,P_N', default=[1.0], help='scheduled sampling probabilities of sampling the truth ' 'labels for N epochs starting from --start-schedule-sampling-epoch; ' 'all later epochs using P_N') @@ -129,6 +131,9 @@ def build_model(cls, args, task): # make sure that all args are properly defaulted (in case there are any new ones) base_architecture(args) + max_source_positions = getattr(args, 'max_source_positions', DEFAULT_MAX_SOURCE_POSITIONS) + max_target_positions = getattr(args, 'max_target_positions', DEFAULT_MAX_TARGET_POSITIONS) + def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): num_embeddings = len(dictionary) padding_idx = 
dictionary.pad() @@ -207,6 +212,7 @@ def eval_str_nested_list_or_tuple(x, type=int): dropout_out=args.encoder_rnn_dropout_out, bidirectional=args.encoder_rnn_bidirectional, residual=args.encoder_rnn_residual, + max_source_positions=max_source_positions, ) decoder = SpeechLSTMDecoder( dictionary=task.target_dictionary, @@ -227,6 +233,7 @@ def eval_str_nested_list_or_tuple(x, type=int): options.eval_str_list(args.adaptive_softmax_cutoff, type=int) if args.criterion == 'adaptive_loss' else None ), + max_target_positions=max_target_positions, scheduled_sampling_rate_scheduler=scheduled_sampling_rate_scheduler, ) pretrained_lm = None @@ -254,107 +261,6 @@ def max_decoder_positions(self): min(self.decoder.max_positions(), self.pretrained_lm.max_positions()) -@register_model('lstm_lm') -class LSTMLanguageModel(FairseqLanguageModel): - def __init__(self, decoder, args): - super().__init__(decoder) - self.is_wordlm = args.is_wordlm - - @staticmethod - def add_args(parser): - """Add model-specific arguments to the parser.""" - # fmt: off - parser.add_argument('--dropout', type=float, metavar='D', - help='dropout probability') - parser.add_argument('--decoder-embed-dim', type=int, metavar='N', - help='decoder embedding dimension') - parser.add_argument('--decoder-embed-path', type=str, metavar='STR', - help='path to pre-trained decoder embedding') - parser.add_argument('--decoder-freeze-embed', action='store_true', - help='freeze decoder embeddings') - parser.add_argument('--decoder-hidden-size', type=int, metavar='N', - help='decoder hidden size') - parser.add_argument('--decoder-layers', type=int, metavar='N', - help='number of decoder layers') - parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N', - help='decoder output embedding dimension') - parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', - help='comma separated list of adaptive softmax cutoff points. ' - 'Must be used with adaptive_loss criterion') - parser.add_argument('--share-embed', - type=lambda x: options.eval_bool(x), - help='share input and output embeddings') - parser.add_argument('--is-wordlm', action='store_true', - help='whether it is word LM or subword LM. 
Only ' - 'relevant for ASR decoding with LM, and it determines ' - 'how the underlying decoder instance gets the dictionary' - 'from the task instance when calling cls.build_model()') - - # Granular dropout settings (if not specified these default to --dropout) - parser.add_argument('--decoder-dropout-in', type=float, metavar='D', - help='dropout probability for decoder input embedding') - parser.add_argument('--decoder-dropout-out', type=float, metavar='D', - help='dropout probability for decoder output') - # fmt: on - - @classmethod - def build_model(cls, args, task): - """Build a new model instance.""" - # make sure all arguments are present in older models - base_lm_architecture(args) - - def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): - num_embeddings = len(dictionary) - padding_idx = dictionary.pad() - embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx) - embed_dict = utils.parse_embedding(embed_path) - utils.print_embed_overlap(embed_dict, dictionary) - return utils.load_embedding(embed_dict, dictionary, embed_tokens) - - if args.is_wordlm and hasattr(task, 'word_dictionary'): - dictionary = task.word_dictionary - elif isinstance(task, SpeechRecognitionEspressoTask): - dictionary = task.target_dictionary - else: - dictionary = task.source_dictionary - - # separate decoder input embeddings - pretrained_decoder_embed = None - if args.decoder_embed_path: - pretrained_decoder_embed = load_pretrained_embedding_from_file( - args.decoder_embed_path, - dictionary, - args.decoder_embed_dim - ) - # one last double check of parameter combinations - if args.share_embed and ( - args.decoder_embed_dim != args.decoder_out_embed_dim): - raise ValueError( - '--share-embed requires ' - '--decoder-embed-dim to match --decoder-out-embed-dim' - ) - - if args.decoder_freeze_embed: - pretrained_decoder_embed.weight.requires_grad = False - - decoder = SpeechLSTMDecoder( - dictionary=dictionary, - embed_dim=args.decoder_embed_dim, - hidden_size=args.decoder_hidden_size, - out_embed_dim=args.decoder_out_embed_dim, - num_layers=args.decoder_layers, - dropout_in=args.decoder_dropout_in, - dropout_out=args.decoder_dropout_out, - pretrained_embed=pretrained_decoder_embed, - share_input_output_embed=args.share_embed, - adaptive_softmax_cutoff=( - options.eval_str_list(args.adaptive_softmax_cutoff, type=int) - if args.criterion == 'adaptive_loss' else None - ), - ) - return LSTMLanguageModel(decoder, args) - - class ConvBNReLU(nn.Module): """Sequence of convolution-BatchNorm-ReLU layers.""" def __init__(self, out_channels, kernel_sizes, strides, in_channels=1): @@ -415,6 +321,7 @@ def __init__( self, conv_layers_before=None, input_size=83, hidden_size=512, num_layers=1, dropout_in=0.1, dropout_out=0.1, bidirectional=False, residual=False, left_pad=False, pretrained_embed=None, padding_value=0., + max_source_positions=DEFAULT_MAX_SOURCE_POSITIONS, ): super().__init__(None) # no src dictionary self.conv_layers_before = conv_layers_before @@ -424,6 +331,7 @@ def __init__( self.bidirectional = bidirectional self.hidden_size = hidden_size self.residual = residual + self.max_source_positions = max_source_positions self.lstm = nn.ModuleList([ LSTM( @@ -505,7 +413,7 @@ def reorder_encoder_out(self, encoder_out, new_order): def max_positions(self): """Maximum input length supported by the encoder.""" - return int(1e5) # an arbitrary large number + return self.max_source_positions class SpeechLSTMDecoder(FairseqIncrementalDecoder): @@ -515,6 +423,7 @@ def __init__( num_layers=1, 
dropout_in=0.1, dropout_out=0.1, encoder_output_units=0, attn_type=None, attn_dim=0, need_attn=False, residual=False, pretrained_embed=None, share_input_output_embed=False, adaptive_softmax_cutoff=None, + max_target_positions=DEFAULT_MAX_TARGET_POSITIONS, scheduled_sampling_rate_scheduler=None, ): super().__init__(dictionary) @@ -528,6 +437,7 @@ def __init__( encoder_output_units = 0 self.need_attn = need_attn self.residual = residual + self.max_target_positions = max_target_positions self.adaptive_softmax = None num_embeddings = len(dictionary) @@ -634,13 +544,17 @@ def extract_features( - the decoder's features of shape `(batch, tgt_len, embed_dim)` - attention weights of shape `(batch, tgt_len, src_len)` """ - if self.attention is not None: - assert encoder_out is not None + if encoder_out is not None: + assert self.attention is not None encoder_padding_mask = encoder_out['encoder_padding_mask'] encoder_out = encoder_out['encoder_out'] # get outputs from encoder encoder_outs = encoder_out[0] srclen = encoder_outs.size(0) + else: + encoder_padding_mask = None + encoder_out = None + srclen = None if incremental_state is not None: prev_output_tokens = prev_output_tokens[:, -1:] @@ -659,18 +573,20 @@ def extract_features( prev_hiddens, prev_cells, input_feed = cached_state else: num_layers = len(self.layers) - prev_hiddens = [x.new_zeros(bsz, self.hidden_size) for i in range(num_layers)] - prev_cells = [x.new_zeros(bsz, self.hidden_size) for i in range(num_layers)] + zero_state = x.new_zeros(bsz, self.hidden_size) + prev_hiddens = [zero_state for i in range(num_layers)] + prev_cells = [zero_state for i in range(num_layers)] input_feed = x.new_zeros(bsz, self.encoder_output_units) \ - if self.attention is not None else None + if encoder_out is not None else None - if self.attention is not None: - attn_scores = x.new_zeros(srclen, seqlen, bsz) + attn_scores = x.new_zeros(srclen, seqlen, bsz) if encoder_out is not None else None outs = [] for j in range(seqlen): # input feeding: concatenate context vector from previous time step - input = torch.cat((x[j, :, :], input_feed), dim=1) \ - if input_feed is not None else x[j, :, :] + if input_feed is not None: + input = torch.cat((x[j, :, :], input_feed), dim=1) + else: + input = x[j, :, :] for i, rnn in enumerate(self.layers): # recurrent cell @@ -679,7 +595,7 @@ def extract_features( prev_layer_hidden = input[:, :hidden.size(1)] # compute and apply attention using the 1st layer's hidden state - if self.attention is not None: + if encoder_out is not None: if i == 0: context, attn_scores[:, j, :], _ = self.attention( hidden, encoder_outs, encoder_padding_mask, @@ -692,7 +608,7 @@ def extract_features( input = hidden input = F.dropout(input, p=self.dropout_out, training=self.training) if self.residual and i > 0: - if self.attention is not None: + if encoder_out is not None: hidden_sum = input[:, :hidden.size(1)] + prev_layer_hidden input = torch.cat((hidden_sum, input[:, hidden.size(1):]), dim=1) else: @@ -703,7 +619,8 @@ def extract_features( prev_cells[i] = cell # input feeding - input_feed = context if self.attention is not None else None + if input_feed is not None: + input_feed = context # save final output outs.append(input) @@ -726,7 +643,7 @@ def extract_features( x = F.dropout(x, p=self.dropout_out, training=self.training) # srclen x tgtlen x bsz -> bsz x tgtlen x srclen - if not self.training and self.attention is not None and self.need_attn: + if not self.training and encoder_out is not None and self.need_attn: attn_scores = 
attn_scores.transpose(0, 2) else: attn_scores = None @@ -753,7 +670,10 @@ def reorder_incremental_state(self, incremental_state, new_order): def reorder_state(state): if isinstance(state, list): return [reorder_state(state_i) for state_i in state] - return state.index_select(0, new_order) if state is not None else None + elif state is not None: + return state.index_select(0, new_order) + else: + return None new_state = tuple(map(reorder_state, cached_state)) utils.set_incremental_state(self, incremental_state, 'cached_state', new_state) @@ -786,7 +706,7 @@ def mask_copy_state(state, another_state): def max_positions(self): """Maximum output length supported by the decoder.""" - return int(1e5) # an arbitrary large number + return self.max_target_positions def make_generation_fast_(self, need_attn=False, **kwargs): self.need_attn = need_attn @@ -815,62 +735,6 @@ def Convolution2d(in_channels, out_channels, kernel_size, stride): return m -@register_model_architecture('lstm_lm', 'lstm_lm') -def base_lm_architecture(args): - args.dropout = getattr(args, 'dropout', 0.1) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 48) - args.decoder_embed_path = getattr(args, 'decoder_embed_path', None) - args.decoder_freeze_embed = getattr(args, 'decoder_freeze_embed', False) - args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 650) - args.decoder_layers = getattr(args, 'decoder_layers', 2) - args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 650) - args.decoder_rnn_residual = getattr(args, 'decoder_rnn_residual', False) - args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None) - args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', args.dropout) - args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout) - args.share_embed = getattr(args, 'share_embed', False) - args.is_wordlm = getattr(args, 'is_wordlm', False) - - -@register_model_architecture('lstm_lm', 'lstm_lm_wsj') -def lstm_lm_wsj(args): - base_lm_architecture(args) - - -@register_model_architecture('lstm_lm', 'lstm_lm_librispeech') -def lstm_lm_librispeech(args): - args.dropout = getattr(args, 'dropout', 0.0) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 800) - args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 800) - args.decoder_layers = getattr(args, 'decoder_layers', 4) - args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 800) - args.share_embed = getattr(args, 'share_embed', True) - base_lm_architecture(args) - - -@register_model_architecture('lstm_lm', 'lstm_lm_swbd') -def lstm_lm_swbd(args): - args.dropout = getattr(args, 'dropout', 0.3) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1800) - args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 1800) - args.decoder_layers = getattr(args, 'decoder_layers', 3) - args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 1800) - args.share_embed = getattr(args, 'share_embed', True) - base_lm_architecture(args) - - -@register_model_architecture('lstm_lm', 'lstm_wordlm_wsj') -def lstm_wordlm_wsj(args): - args.dropout = getattr(args, 'dropout', 0.35) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1200) - args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 1200) - args.decoder_layers = getattr(args, 'decoder_layers', 3) - args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 1200) - args.share_embed = getattr(args, 'share_embed', True) - args.is_wordlm = True - base_lm_architecture(args) 
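Note on the architecture functions being moved out of speech_lstm.py here: each one fills in any hyperparameter the user did not set explicitly via getattr defaults, and named variants override a few fields before deferring to the base. A minimal sketch of the pattern with made-up names:

import argparse

def base_toy_architecture(args):
    # Only set fields the user did not pass on the command line.
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 48)
    args.decoder_layers = getattr(args, 'decoder_layers', 2)

def toy_architecture_large(args):
    # Variant: override selected fields, then fall back to the base defaults.
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 800)
    base_toy_architecture(args)

args = argparse.Namespace(decoder_layers=3)
toy_architecture_large(args)
assert (args.decoder_embed_dim, args.decoder_layers) == (800, 3)
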
- - @register_model_architecture('speech_lstm', 'speech_lstm') def base_architecture(args): args.dropout = getattr(args, 'dropout', 0.4) diff --git a/espresso/speech_train.py b/espresso/speech_train.py index b8f0f55a0..d92a6064e 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -118,9 +118,11 @@ def main(args, init_distributed=False): logger.info('early stop since valid performance hasn\'t improved for last {} runs'.format(args.patience)) break - reload_dataset = len(args.train_feat_files) > 1 - # sharded data: get train iterator for next epoch - epoch_itr = trainer.get_train_iterator(epoch_itr.epoch, load_dataset=reload_dataset) + epoch_itr = trainer.get_train_iterator( + epoch_itr.epoch, + # sharded data: get train iterator for next epoch + load_dataset=(len(args.train_feat_files) > 1), + ) train_meter.stop() logger.info('done training in {:.1f} seconds'.format(train_meter.sum)) diff --git a/espresso/tools/scheduled_sampling_rate_scheduler.py b/espresso/tools/scheduled_sampling_rate_scheduler.py index 2b9b029d0..9cc438441 100644 --- a/espresso/tools/scheduled_sampling_rate_scheduler.py +++ b/espresso/tools/scheduled_sampling_rate_scheduler.py @@ -28,9 +28,11 @@ def __init__( def step(self, epoch: int) -> float: if ( - (len(self.scheduled_sampling_probs) > 1 or - self.scheduled_sampling_probs[0] < 1.0) and - epoch >= self.start_scheduled_sampling_epoch + ( + len(self.scheduled_sampling_probs) > 1 + or self.scheduled_sampling_probs[0] < 1.0 + ) + and epoch >= self.start_scheduled_sampling_epoch ): prob = self.scheduled_sampling_probs[ min(epoch - self.start_scheduled_sampling_epoch, diff --git a/espresso/tools/wer.py b/espresso/tools/wer.py index f8b7618b5..c5190d247 100644 --- a/espresso/tools/wer.py +++ b/espresso/tools/wer.py @@ -3,11 +3,9 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from collections import Counter, OrderedDict import logging import re -import sys - -from collections import Counter, OrderedDict import espresso.tools.utils as speech_utils From ad102bdd1e1bc4a34d4df680519074673106a4d3 Mon Sep 17 00:00:00 2001 From: freewym Date: Thu, 30 Jan 2020 20:23:01 -0500 Subject: [PATCH 065/119] code adaptation/changes according to the commits on Jan 30, 2020 --- espresso/tools/simple_greedy_decoder.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/espresso/tools/simple_greedy_decoder.py b/espresso/tools/simple_greedy_decoder.py index acb929ce4..593b2041a 100644 --- a/espresso/tools/simple_greedy_decoder.py +++ b/espresso/tools/simple_greedy_decoder.py @@ -108,6 +108,10 @@ def _decode(self, model, sample, bos_token=None, **kwargs): tokens[is_eos, step + 1] = self.eos if self.for_validation and step < target.size(1): lprobs[:, step, :] = log_probs + + # Record attention scores + if type(avg_attn_scores) is list: + avg_attn_scores = avg_attn_scores[0] if avg_attn_scores is not None: if attn is None: attn = avg_attn_scores.new(bsz, max_encoder_output_length, max_len + 2) From 19ec973517d4b33cf6964ae9c5417b3eee3828d1 Mon Sep 17 00:00:00 2001 From: freewym Date: Wed, 12 Feb 2020 13:35:59 -0500 Subject: [PATCH 066/119] code adaptation/changes according to the commits on Feb 12, 2020 --- espresso/criterions/cross_entropy_v2.py | 2 +- espresso/criterions/label_smoothed_cross_entropy_v2.py | 4 ++-- espresso/tasks/speech_recognition.py | 10 +++++----- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/espresso/criterions/cross_entropy_v2.py b/espresso/criterions/cross_entropy_v2.py index 5c97873b2..95b4a823c 100644 --- a/espresso/criterions/cross_entropy_v2.py +++ b/espresso/criterions/cross_entropy_v2.py @@ -50,7 +50,7 @@ def forward(self, model, sample, reduce=True): loss, _, lprobs = self.compute_loss(model, net_output, sample, reduce=reduce) sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] logging_output = { - 'loss': utils.item(loss.data) if reduce else loss.data, + 'loss': loss.data, 'ntokens': sample['ntokens'], 'nsentences': sample['target'].size(0), 'sample_size': sample_size, diff --git a/espresso/criterions/label_smoothed_cross_entropy_v2.py b/espresso/criterions/label_smoothed_cross_entropy_v2.py index 9b4b797b3..a441ab48c 100644 --- a/espresso/criterions/label_smoothed_cross_entropy_v2.py +++ b/espresso/criterions/label_smoothed_cross_entropy_v2.py @@ -126,8 +126,8 @@ def forward(self, model, sample, reduce=True): ) sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] logging_output = { - 'loss': utils.item(loss.data) if reduce else loss.data, - 'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data, + 'loss': loss.data, + 'nll_loss': nll_loss.data, 'ntokens': sample['ntokens'], 'nsentences': sample['target'].size(0), 'sample_size': sample_size, diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index 227569703..21c6c9cf8 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -8,7 +8,7 @@ import torch -from fairseq import metrics, options, search +from fairseq import metrics, options, search, utils from fairseq.data import ConcatDataset from fairseq.tasks import FairseqTask, register_task @@ -316,10 +316,10 @@ def inference_step(self, generator, models, sample, prefix_tokens=None, lm_weigh def reduce_metrics(self, logging_outputs, criterion): super().reduce_metrics(logging_outputs, 
criterion) - word_error = sum(log.get('word_error', 0) for log in logging_outputs) - word_count = sum(log.get('word_count', 0) for log in logging_outputs) - char_error = sum(log.get('char_error', 0) for log in logging_outputs) - char_count = sum(log.get('char_count', 0) for log in logging_outputs) + word_error = utils.item(sum(log.get('word_error', 0) for log in logging_outputs)) + word_count = utils.item(sum(log.get('word_count', 0) for log in logging_outputs)) + char_error = utils.item(sum(log.get('char_error', 0) for log in logging_outputs)) + char_count = utils.item(sum(log.get('char_count', 0) for log in logging_outputs)) if word_count > 0: metrics.log_scalar('wer', float(word_error) / word_count * 100, word_count, round=4) if char_count > 0: From 9c69029b82ecf11bf4d39a498a14cc8f8ddcc847 Mon Sep 17 00:00:00 2001 From: Yiming Wang Date: Thu, 13 Feb 2020 21:26:32 -0500 Subject: [PATCH 067/119] add options to accept utt2num_frames files to speed up the data loading (#22) * add options to accept utt2num_frames files to speed up the data loading --- espresso/data/scp_text_dataset.py | 27 ++++++++++++----- espresso/tasks/speech_recognition.py | 43 ++++++++++++++++++++++++---- examples/asr_librispeech/run.sh | 9 ++++-- examples/asr_swbd/run.sh | 9 ++++-- examples/asr_wsj/run.sh | 10 +++++-- 5 files changed, 76 insertions(+), 22 deletions(-) diff --git a/espresso/data/scp_text_dataset.py b/espresso/data/scp_text_dataset.py index 436995d6c..898480348 100644 --- a/espresso/data/scp_text_dataset.py +++ b/espresso/data/scp_text_dataset.py @@ -22,25 +22,38 @@ class ScpDataset(torch.utils.data.Dataset): every time each entry is inquired, thus incurs the most intensive I/O. """ - def __init__(self, path): + def __init__(self, path, utt2num_frames_path=None): super().__init__() self.dtype = np.float - self.read_scp(path) + self.read_scp(path, utt2num_frames_path) - def read_scp(self, path): + def read_scp(self, path, utt2num_frames_path=None): with open(path, 'r', encoding='utf-8') as f: scp_entries = [line.strip().split(None, 1) for line in f] self.utt_ids = [entry[0] for entry in scp_entries] self.extended_filenames = [entry[1] for entry in scp_entries] self.size = len(scp_entries) # number of utterances self.sizes = [] # length of each utterance + if utt2num_frames_path is not None: + with open(utt2num_frames_path, 'r', encoding='utf-8') as f: + i = 0 + for line in f: + utt_id, num_frames = line.strip().split(None, 1) + assert utt_id == self.utt_ids[i], \ + 'utterance ids mismatch: ' + utt_id + ' vs. ' + self.utt_ids[i] + self.sizes.append(int(num_frames)) + i += 1 + for filename in self.extended_filenames: try: feat = kaldi_io.read_mat(filename) except Exception: raise Exception('failed to read feature matrix {}.'.format(filename)) assert feat is not None and isinstance(feat, np.ndarray) + if len(self.sizes) == self.size: + break self.sizes.append(feat.shape[0]) + self.sizes = np.array(self.sizes, dtype=np.int32) self.feat_dim = feat.shape[1] # feature dimension @@ -84,8 +97,8 @@ class ScpCachedDataset(ScpDataset): It balances the I/O efficiency and memory usage. 
""" - def __init__(self, path, ordered_prefetch=False, cache_size=4096): - super().__init__(path) + def __init__(self, path, utt2num_frames_path=None, ordered_prefetch=False, cache_size=4096): + super().__init__(path, utt2num_frames_path) self.cache = None self.cache_index = {} self.cache_size = cache_size # in terms of number of examples @@ -156,8 +169,8 @@ class ScpInMemoryDataset(ScpDataset): It has the maximum memory usage and least I/O. """ - def __init__(self, path): - super().__init__(path) + def __init__(self, path, utt2num_frames_path=None): + super().__init__(path, utt2num_frames_path) self.read_data() def read_data(self): diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index 21c6c9cf8..14c0961cf 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -59,21 +59,34 @@ def add_args(parser): help='path(s) to text file(s) for training, where ' 'each should matches with one in --train-feat-files, ' 'will be iterated upon during epochs in round-robin manner') + parser.add_argument('--train-utt2num-frames-files', nargs='+', default=None, + help='path(s) to utt2num_frames file(s) for training. if not None, ' + 'each should matches with one in --train-feat-files, ' + 'will be iterated upon during epochs in round-robin manner') parser.add_argument('--valid-feat-files', nargs='+', help='path(s) to scp feature file(s) for validation') parser.add_argument('--valid-text-files', nargs='+', help='path(s) to text file(s) for validation, where ' 'each should matches with one in --valid-feat-files') + parser.add_argument('--valid-utt2num-frames-files', nargs='+', default=None, + help='path(s) to utt2num_frames file(s) for validation. if not None, ' + 'each should matches with one in --valid-feat-files') parser.add_argument('--test-feat-files', nargs='+', help='path(s) to scp feature file(s) for test') - parser.add_argument('--test-text-files', nargs='*', default=None, + parser.add_argument('--test-text-files', nargs='+', default=None, help='path(s) to text file(s) for test. if not None, ' 'each one should matches with one in --test-feat-files') + parser.add_argument('--test-utt2num-frames-files', nargs='+', default=None, + help='path(s) to utt2num_frames file(s) for test. if not None, ' + 'each should matches with one in --test-feat-files') parser.add_argument('--train-subset-feat-files', nargs='+', help='path(s) to scp feature file(s) for validation') parser.add_argument('--train-subset-text-files', nargs='+', help='path(s) to text file(s) for validation, where ' 'each should matches with one in --train-subset-feat-files') + parser.add_argument('--train-subset-utt2num-frames-files', nargs='+', default=None, + help='path(s) to utt2num_frames file(s) for validation. 
if not None, ' + 'each should matches with one in --train-subset-feat-files') parser.add_argument('--dict', default=None, type=str, help='path to the dictionary') parser.add_argument('--non-lang-syms', default=None, type=str, @@ -159,29 +172,47 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs): if split == 'train': feat_files = self.args.train_feat_files text_files = self.args.train_text_files + utt2num_frames_files = self.args.train_utt2num_frames_files # can be None assert len(feat_files) > 0 and len(feat_files) == len(text_files) + assert utt2num_frames_files is None or len(feat_files) == len(utt2num_frames_files) feat_files = [feat_files[epoch % len(feat_files)]] text_files = [text_files[epoch % len(text_files)]] + if utt2num_frames_files is not None: + utt2num_frames_files = [utt2num_frames_files[epoch % len(utt2num_frames_files)]] + else: + utt2num_frames_files = [None] elif split == 'valid': feat_files = self.args.valid_feat_files text_files = self.args.valid_text_files + utt2num_frames_files = self.args.valid_utt2num_frames_files # can be None + if utt2num_frames_files is None: + utt2num_frames_files = [None] * len(feat_files) elif split == 'test': feat_files = self.args.test_feat_files - text_files = self.args.test_text_files # can be empty + text_files = self.args.test_text_files # can be None + utt2num_frames_files = self.args.test_utt2num_frames_files # can be None if text_files is None: text_files = [None] * len(feat_files) + if utt2num_frames_files is None: + utt2num_frames_files = [None] * len(feat_files) elif split == 'train_subset': feat_files = self.args.train_subset_feat_files text_files = self.args.train_subset_text_files + utt2num_frames_files = self.args.train_subset_utt2num_frames_files # can be None + if utt2num_frames_files is None: + utt2num_frames_files = [None] * len(feat_files) else: raise ValueError('split should be one of "train", "valid", "test", "train_subset"') - assert len(feat_files) > 0 and len(feat_files) == len(text_files) - file_pairs = zip(feat_files, text_files) - for feat, text in file_pairs: + assert len(feat_files) > 0 and len(feat_files) == len(text_files) and \ + len(feat_files) == len(utt2num_frames_files) + file_tuples = zip(feat_files, text_files, utt2num_frames_files) + for feat, text, utt2num_frames in file_tuples: assert ScpCachedDataset.exists(feat), feat + ' does not exists' assert text is None or AsrTextDataset.exists(text), text + ' does not exists' - src_datasets.append(ScpCachedDataset(feat, ordered_prefetch=True)) + assert utt2num_frames is None or ScpCachedDataset.exists(utt2num_frames), \ + utt2num_frames + ' does not exists' + src_datasets.append(ScpCachedDataset(feat, utt2num_frames, ordered_prefetch=True)) logger.info('{} {} examples'.format(feat, len(src_datasets[-1]))) if text is not None: tgt_datasets.append(AsrTextDataset(text, self.dictionary)) diff --git a/examples/asr_librispeech/run.sh b/examples/asr_librispeech/run.sh index 59a3fc131..a09335a0b 100755 --- a/examples/asr_librispeech/run.sh +++ b/examples/asr_librispeech/run.sh @@ -197,8 +197,10 @@ fi train_feat=$train_feat_dir/feats.scp train_token_text=data/$train_set/token_text +train_utt2num_frames=data/$train_set/utt2num_frames valid_feat=$valid_feat_dir/feats.scp valid_token_text=data/$valid_set/token_text +valid_utt2num_frames=data/$valid_set/utt2num_frames if [ ${stage} -le 7 ]; then echo "Stage 7: Model Training" valid_subset=valid @@ -217,8 +219,8 @@ if [ ${stage} -le 7 ]; then --arch speech_conv_lstm_librispeech --criterion 
label_smoothed_cross_entropy_v2 \ --label-smoothing 0.1 --smoothing-type uniform \ --scheduled-sampling-probs 1.0 --start-scheduled-sampling-epoch 1 \ - --train-feat-files $train_feat --train-text-files $train_token_text \ - --valid-feat-files $valid_feat --valid-text-files $valid_token_text \ + --train-feat-files $train_feat --train-text-files $train_token_text --train-utt2num-frames-files $train_utt2num_frames \ + --valid-feat-files $valid_feat --valid-text-files $valid_token_text --valid-utt2num-frames-files $valid_utt2num_frames \ --dict $dict --remove-bpe sentencepiece \ --max-source-positions 9999 --max-target-positions 999 2>&1 | tee $log_file fi @@ -236,10 +238,11 @@ if [ ${stage} -le 8 ]; then for dataset in $test_set; do feat=${dumpdir}/$dataset/delta${do_delta}/feats.scp text=data/$dataset/token_text + utt2num_frames=data/$dataset/utt2num_frames decode_dir=$dir/decode_$dataset${decode_affix:+_${decode_affix}} CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py \ --task speech_recognition_espresso --user-dir espresso --max-tokens 15000 --max-sentences 24 \ - --num-shards 1 --shard-id 0 --test-feat-files $feat --test-text-files $text \ + --num-shards 1 --shard-id 0 --test-feat-files $feat --test-text-files $text --test-utt2num-frames-files $utt2num_frames \ --dict $dict --remove-bpe sentencepiece \ --max-source-positions 9999 --max-target-positions 999 \ --path $path --beam 60 --max-len-a 0.08 --max-len-b 0 --lenpen 1.0 \ diff --git a/examples/asr_swbd/run.sh b/examples/asr_swbd/run.sh index 3a645e222..db3bc8ee5 100755 --- a/examples/asr_swbd/run.sh +++ b/examples/asr_swbd/run.sh @@ -235,8 +235,10 @@ fi train_feat=$train_feat_dir/feats.scp train_token_text=data/$train_set/token_text +train_utt2num_frames=data/$train_set/utt2num_frames valid_feat=$valid_feat_dir/feats.scp valid_token_text=data/$valid_set/token_text +valid_utt2num_frames=data/$valid_set/utt2num_frames if [ $stage -le 6 ]; then echo "Stage 6: Model Training" valid_subset=valid @@ -257,8 +259,8 @@ if [ $stage -le 6 ]; then --arch speech_conv_lstm_swbd --criterion label_smoothed_cross_entropy_v2 \ --label-smoothing 0.1 --smoothing-type uniform \ --scheduled-sampling-probs 0.9,0.8,0.7,0.6 --start-scheduled-sampling-epoch 6 \ - --train-feat-files $train_feat --train-text-files $train_token_text \ - --valid-feat-files $valid_feat --valid-text-files $valid_token_text \ + --train-feat-files $train_feat --train-text-files $train_token_text --train-utt2num-frames-files $train_utt2num_frames \ + --valid-feat-files $valid_feat --valid-text-files $valid_token_text --valid-utt2num-frames-files $valid_utt2num_frames \ --dict $dict --remove-bpe sentencepiece --non-lang-syms $nlsyms \ --max-source-positions 9999 --max-target-positions 999 $opts 2>&1 | tee $log_file fi @@ -277,12 +279,13 @@ if [ $stage -le 7 ]; then [ -f local/wer_output_filter ] && opts="$opts --wer-output-filter local/wer_output_filter" for dataset in $test_set; do feat=${dumpdir}/$dataset/delta${do_delta}/feats.scp + utt2num_frames=data/$dataset/utt2num_frames decode_dir=$dir/decode_${dataset}${decode_affix:+_${decode_affix}} # only score train_dev with built-in scorer text_opt= && [ "$dataset" == "train_dev" ] && text_opt="--test-text-files data/$dataset/token_text" CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py \ --task speech_recognition_espresso --user-dir espresso --max-tokens 24000 --max-sentences 48 \ - --num-shards 1 --shard-id 0 --test-feat-files $feat $text_opt \ + 
--num-shards 1 --shard-id 0 --test-feat-files $feat $text_opt --test-utt2num-frames-files $utt2num_frames \ --dict $dict --remove-bpe sentencepiece --non-lang-syms $nlsyms \ --max-source-positions 9999 --max-target-positions 999 \ --path $path --beam 35 --max-len-a 0.1 --max-len-b 0 --lenpen 1.0 \ diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index a50f81a8b..20b2923ff 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -249,8 +249,10 @@ fi train_feat=$train_feat_dir/feats.scp train_token_text=data/$train_set/token_text +train_utt2num_frames=data/$train_set/utt2num_frames valid_feat=$valid_feat_dir/feats.scp valid_token_text=data/$valid_set/token_text +valid_utt2num_frames=data/$valid_set/utt2num_frames if [ ${stage} -le 8 ]; then echo "Stage 8: Model Training" opts="" @@ -259,6 +261,7 @@ if [ ${stage} -le 8 ]; then valid_subset="$valid_subset,train_subset" opts="$opts --train-subset-feat-files $train_subset_feat_dir/feats.scp" opts="$opts --train-subset-text-files data/${train_set}_${train_subset_size}/token_text" + opts="$opts --train-subset-utt2num-frames-files data/${train_set}_${train_subset_size}/utt2num_frames" fi [ -f local/wer_output_filter ] && opts="$opts --wer-output-filter local/wer_output_filter" mkdir -p $dir/logs @@ -276,8 +279,8 @@ if [ ${stage} -le 8 ]; then --arch speech_conv_lstm_wsj --criterion label_smoothed_cross_entropy_v2 \ --label-smoothing 0.05 --smoothing-type temporal \ --scheduled-sampling-probs 0.5 --start-scheduled-sampling-epoch 6 \ - --train-feat-files $train_feat --train-text-files $train_token_text \ - --valid-feat-files $valid_feat --valid-text-files $valid_token_text \ + --train-feat-files $train_feat --train-text-files $train_token_text --train-utt2num-frames-files $train_utt2num_frames \ + --valid-feat-files $valid_feat --valid-text-files $valid_token_text --valid-utt2num-frames-files $valid_utt2num_frames \ --dict $dict --non-lang-syms $nlsyms \ --max-source-positions 9999 --max-target-positions 999 $opts 2>&1 | tee $log_file fi @@ -306,10 +309,11 @@ if [ ${stage} -le 9 ]; then feat=$test_feat_dir/feats.scp fi text=data/$dataset/token_text + utt2num_frames=data/$dataset/utt2num_frames decode_dir=$dir/decode_$dataset${decode_affix:+_${decode_affix}} CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py \ --task speech_recognition_espresso --user-dir espresso --max-tokens 20000 --max-sentences 32 \ - --num-shards 1 --shard-id 0 --test-feat-files $feat --test-text-files $text \ + --num-shards 1 --shard-id 0 --test-feat-files $feat --test-text-files $text --test-utt2num-frames-files $utt2num_frames \ --dict $dict --non-lang-syms $nlsyms \ --max-source-positions 9999 --max-target-positions 999 \ --path $path --beam 50 --max-len-a 0.2 --max-len-b 0 --lenpen 1.0 \ From 698305fa6379ff0b57e327969e413e8b31537b28 Mon Sep 17 00:00:00 2001 From: Yiming Wang Date: Sat, 15 Feb 2020 16:43:51 -0500 Subject: [PATCH 068/119] use json files to simplify the cli options for input data (#23) --- espresso/data/scp_text_dataset.py | 92 ++++--- espresso/speech_recognize.py | 3 +- espresso/speech_train.py | 3 +- espresso/tasks/speech_recognition.py | 358 ++++++++++++-------------- espresso/tools/asr_prep_json.py | 65 +++++ espresso/tools/wer.py | 3 + examples/asr_librispeech/run.sh | 42 +-- examples/asr_swbd/run.sh | 45 ++-- examples/asr_wsj/run.sh | 57 ++-- tests/espresso/test_speech_dataset.py | 91 ++++--- 10 files changed, 413 insertions(+), 346 deletions(-) create mode 100755 
espresso/tools/asr_prep_json.py diff --git a/espresso/data/scp_text_dataset.py b/espresso/data/scp_text_dataset.py index 898480348..c81ddea25 100644 --- a/espresso/data/scp_text_dataset.py +++ b/espresso/data/scp_text_dataset.py @@ -4,8 +4,10 @@ # LICENSE file in the root directory of this source tree. import os +from typing import List, Optional import numpy as np + import torch try: @@ -22,44 +24,34 @@ class ScpDataset(torch.utils.data.Dataset): every time each entry is inquired, thus incurs the most intensive I/O. """ - def __init__(self, path, utt2num_frames_path=None): + def __init__( + self, utt_ids: List[str], rxfiles: List[str], utt2num_frames: Optional[List[int]] = None, + ): super().__init__() + assert len(utt_ids) == len(rxfiles) self.dtype = np.float - self.read_scp(path, utt2num_frames_path) - - def read_scp(self, path, utt2num_frames_path=None): - with open(path, 'r', encoding='utf-8') as f: - scp_entries = [line.strip().split(None, 1) for line in f] - self.utt_ids = [entry[0] for entry in scp_entries] - self.extended_filenames = [entry[1] for entry in scp_entries] - self.size = len(scp_entries) # number of utterances + self.utt_ids = utt_ids + self.rxfiles = rxfiles + self.size = len(utt_ids) # number of utterances self.sizes = [] # length of each utterance - if utt2num_frames_path is not None: - with open(utt2num_frames_path, 'r', encoding='utf-8') as f: - i = 0 - for line in f: - utt_id, num_frames = line.strip().split(None, 1) - assert utt_id == self.utt_ids[i], \ - 'utterance ids mismatch: ' + utt_id + ' vs. ' + self.utt_ids[i] - self.sizes.append(int(num_frames)) - i += 1 - - for filename in self.extended_filenames: + if utt2num_frames is not None and len(utt2num_frames) > 0: + assert len(utt2num_frames) == self.size + self.sizes = utt2num_frames + + for rxfile in self.rxfiles: try: - feat = kaldi_io.read_mat(filename) + feat = kaldi_io.read_mat(rxfile) except Exception: - raise Exception('failed to read feature matrix {}.'.format(filename)) + raise Exception('failed to read feature matrix {}.'.format(rxfile)) assert feat is not None and isinstance(feat, np.ndarray) if len(self.sizes) == self.size: break self.sizes.append(feat.shape[0]) + assert len(self.sizes) == self.size self.sizes = np.array(self.sizes, dtype=np.int32) self.feat_dim = feat.shape[1] # feature dimension - assert len(self.utt_ids) == len(self.extended_filenames) and \ - len(self.utt_ids) == len(self.sizes) - def check_index(self, i): if i < 0 or i >= self.size: raise IndexError('index out of range') @@ -71,14 +63,14 @@ def filter_and_reorder(self, indices): assert len(np.unique(indices)) == len(indices), \ 'Duplicate elements in indices.' self.utt_ids = [self.utt_ids[i] for i in indices] - self.extended_filenames = [self.extended_filenames[i] for i in indices] + self.rxfiles = [self.rxfiles[i] for i in indices] self.sizes = self.sizes[indices] self.size = len(self.utt_ids) self.ordered_indices = list(range(self.size)) def __getitem__(self, i): self.check_index(i) - feat = kaldi_io.read_mat(self.extended_filenames[i]) + feat = kaldi_io.read_mat(self.rxfiles[i]) item = torch.from_numpy(feat).float() return item @@ -97,8 +89,11 @@ class ScpCachedDataset(ScpDataset): It balances the I/O efficiency and memory usage. 
""" - def __init__(self, path, utt2num_frames_path=None, ordered_prefetch=False, cache_size=4096): - super().__init__(path, utt2num_frames_path) + def __init__( + self, utt_ids: List[str], rxfiles: List[str], utt2num_frames: Optional[List[int]] = None, + ordered_prefetch=False, cache_size=4096, + ): + super().__init__(utt_ids, rxfiles, utt2num_frames=utt2num_frames) self.cache = None self.cache_index = {} self.cache_size = cache_size # in terms of number of examples @@ -155,7 +150,7 @@ def __getitem__(self, i): self.cache_index[idx] = ptx length = self.sizes[idx] dst = self.cache[ptx: ptx + length] - np.copyto(dst, kaldi_io.read_mat(self.extended_filenames[idx])) + np.copyto(dst, kaldi_io.read_mat(self.rxfiles[idx])) ptx += length ptx = self.cache_index[i] @@ -169,8 +164,10 @@ class ScpInMemoryDataset(ScpDataset): It has the maximum memory usage and least I/O. """ - def __init__(self, path, utt2num_frames_path=None): - super().__init__(path, utt2num_frames_path) + def __init__( + self, utt_ids: List[str], rxfiles: List[str], utt2num_frames: Optional[List[int]] = None, + ): + super().__init__(utt_ids, rxfiles, utt2num_frames=utt2num_frames) self.read_data() def read_data(self): @@ -182,7 +179,7 @@ def read_data(self): for i in range(len(self.data_offsets)): ptx = self.data_offsets[i] dst = self.buffer[ptx: ptx + self.sizes[i]] - np.copyto(dst, kaldi_io.read_mat(self.extended_filenames[i])) + np.copyto(dst, kaldi_io.read_mat(self.rxfiles[i])) def filter_and_reorder(self, indices): super().filter_and_reorder(indices) @@ -200,29 +197,26 @@ class AsrTextDataset(torch.utils.data.Dataset): Original lines are also kept in memory. Each line of the text file is in the format of 'utt_id tokenized_text'.""" - def __init__(self, path, dictionary, append_eos=True): + def __init__(self, utt_ids: List[str], token_text: List[str], dictionary, append_eos=True): super().__init__() self.dtype = np.float self.append_eos = append_eos - self.read_text(path, dictionary) + self.read_text(utt_ids, token_text, dictionary) - def read_text(self, path, dictionary): - self.utt_ids = [] - self.tokens_list = [] + def read_text(self, utt_ids: List[str], token_text: List[str], dictionary): + assert len(utt_ids) == len(token_text) + self.utt_ids = utt_ids + self.tokens_list = token_text self.tensor_list = [] + self.size = len(self.utt_ids) # number of utterances self.sizes = [] - with open(path, 'r', encoding='utf-8') as f: - for line in f: - utt_id, tokens = line.strip().split(None, 1) - self.utt_ids.append(utt_id) - self.tokens_list.append(tokens) - tensor = dictionary.encode_line( - tokens, add_if_not_exist=False, append_eos=self.append_eos, - ).long() - self.tensor_list.append(tensor) - self.sizes.append(len(self.tensor_list[-1])) + for tokens in self.tokens_list: + tensor = dictionary.encode_line( + tokens, add_if_not_exist=False, append_eos=self.append_eos, + ).long() + self.tensor_list.append(tensor) + self.sizes.append(len(self.tensor_list[-1])) - self.size = len(self.utt_ids) # number of utterances self.sizes = np.array(self.sizes, dtype=np.int32) assert len(self.utt_ids) == len(self.tokens_list) and \ diff --git a/espresso/speech_recognize.py b/espresso/speech_recognize.py index d025e7875..f9218a9c9 100755 --- a/espresso/speech_recognize.py +++ b/espresso/speech_recognize.py @@ -210,8 +210,7 @@ def _main(args, output_file): logger.info('Saved attention plots in ' + save_dir) if has_target: - assert args.test_text_files is not None - scorer.add_ordered_utt_list(*args.test_text_files) + 
scorer.add_ordered_utt_list(task.datasets[args.gen_subset].tgt.utt_ids) fn = 'decoded_char_results.txt' with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f: diff --git a/espresso/speech_train.py b/espresso/speech_train.py index d92a6064e..106bf8edd 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -10,6 +10,7 @@ import logging import math +import os import random import sys @@ -121,7 +122,7 @@ def main(args, init_distributed=False): epoch_itr = trainer.get_train_iterator( epoch_itr.epoch, # sharded data: get train iterator for next epoch - load_dataset=(len(args.train_feat_files) > 1), + load_dataset=(os.pathsep in getattr(args, 'data', '')), ) train_meter.stop() logger.info('done training in {:.1f} seconds'.format(train_meter.sum)) diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index 14c0961cf..c6d5e8f3b 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -3,12 +3,15 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from collections import OrderedDict +import itertools +import json import logging import os import torch -from fairseq import metrics, options, search, utils +from fairseq import metrics, search, utils from fairseq.data import ConcatDataset from fairseq.tasks import FairseqTask, register_task @@ -24,13 +27,99 @@ logger = logging.getLogger(__name__) +def get_asr_dataset_from_json( + data_path, split, tgt_dict, + combine, upsample_primary, + max_source_positions, max_target_positions, +): + """ + Parse data json and create dataset. + See espresso/tools/asr_prep_json.py which pack json from raw files + Json example: + { + "011c0202": { + "feat": "fbank/raw_fbank_pitch_train_si284.1.ark:54819", + "token_text": "T H E H O T E L", + "utt2num_frames": "693", + }, + "011c0203": { + ... 
+ } + } + """ + src_datasets = [] + tgt_datasets = [] + for k in itertools.count(): + split_k = split + (str(k) if k > 0 else "") + data_json_path = os.path.join(data_path, "{}.json".format(split_k)) + if not os.path.isfile(data_json_path): + if k > 0: + break + else: + raise FileNotFoundError("Dataset not found: {}".format(data_json_path)) + + with open(data_json_path, "rb") as f: + loaded_json = json.load(f, object_pairs_hook=OrderedDict) + + utt_ids, feats, token_text, utt2num_frames = [], [], [], [] + for utt_id, val in loaded_json.items(): + utt_ids.append(utt_id) + feats.append(val["feat"]) + if "token_text" in val: + token_text.append(val["token_text"]) + if "utt2num_frames" in val: + utt2num_frames.append(int(val["utt2num_frames"])) + + assert len(utt2num_frames) == 0 or len(utt_ids) == len(utt2num_frames) + src_datasets.append(ScpCachedDataset( + utt_ids, feats, utt2num_frames=utt2num_frames, ordered_prefetch=True + )) + if len(token_text) > 0: + assert len(utt_ids) == len(token_text) + assert tgt_dict is not None + tgt_datasets.append(AsrTextDataset(utt_ids, token_text, tgt_dict)) + + logger.info("{} {} examples".format(data_json_path, len(src_datasets[-1]))) + + if not combine: + break + + if len(tgt_datasets) > 0: + assert len(src_datasets) == len(tgt_datasets) + + feat_dim = src_datasets[0].feat_dim + + if len(src_datasets) == 1: + src_dataset = src_datasets[0] + tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None + else: + for i in range(1, len(src_datasets)): + assert feat_dim == src_datasets[i].feat_dim, \ + "feature dimension does not match across multiple json files" + sample_ratios = [1] * len(src_datasets) + sample_ratios[0] = upsample_primary + src_dataset = ConcatDataset(src_datasets, sample_ratios) + tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) \ + if len(tgt_datasets) > 0 else None + + return SpeechDataset( + src_dataset, src_dataset.sizes, + tgt_dataset, tgt_dataset.sizes if tgt_dataset is not None else None, + tgt_dict, + left_pad_source=False, + left_pad_target=False, + max_source_positions=max_source_positions, + max_target_positions=max_target_positions, + ) + + @register_task('speech_recognition_espresso') class SpeechRecognitionEspressoTask(FairseqTask): """ Transcribe from speech (source) to token text (target). Args: - dictionary (~fairseq.data.AsrDictionary): dictionary for the output tokens + tgt_dict (~fairseq.data.AsrDictionary): dictionary for the output tokens word_dict (~fairseq.data.AsrDictionary): dictionary for the words (for decoding with word-based LMs) feat_in_channels (int): input feature channels @@ -52,62 +141,24 @@ class SpeechRecognitionEspressoTask(FairseqTask): def add_args(parser): """Add task-specific arguments to the parser.""" # fmt: off - parser.add_argument('--train-feat-files', nargs='+', - help='path(s) to scp feature file(s) for training, ' - 'will be iterated upon during epochs in round-robin manner') - parser.add_argument('--train-text-files', nargs='+', - help='path(s) to text file(s) for training, where ' - 'each should matches with one in --train-feat-files, ' - 'will be iterated upon during epochs in round-robin manner') - parser.add_argument('--train-utt2num-frames-files', nargs='+', default=None, - help='path(s) to utt2num_frames file(s) for training. 
if not None, ' - 'each should matches with one in --train-feat-files, ' - 'will be iterated upon during epochs in round-robin manner') - parser.add_argument('--valid-feat-files', nargs='+', - help='path(s) to scp feature file(s) for validation') - parser.add_argument('--valid-text-files', nargs='+', - help='path(s) to text file(s) for validation, where ' - 'each should matches with one in --valid-feat-files') - parser.add_argument('--valid-utt2num-frames-files', nargs='+', default=None, - help='path(s) to utt2num_frames file(s) for validation. if not None, ' - 'each should matches with one in --valid-feat-files') - parser.add_argument('--test-feat-files', nargs='+', - help='path(s) to scp feature file(s) for test') - parser.add_argument('--test-text-files', nargs='+', default=None, - help='path(s) to text file(s) for test. if not None, ' - 'each one should matches with one in --test-feat-files') - parser.add_argument('--test-utt2num-frames-files', nargs='+', default=None, - help='path(s) to utt2num_frames file(s) for test. if not None, ' - 'each should matches with one in --test-feat-files') - parser.add_argument('--train-subset-feat-files', nargs='+', - help='path(s) to scp feature file(s) for validation') - parser.add_argument('--train-subset-text-files', nargs='+', - help='path(s) to text file(s) for validation, where ' - 'each should matches with one in --train-subset-feat-files') - parser.add_argument('--train-subset-utt2num-frames-files', nargs='+', default=None, - help='path(s) to utt2num_frames file(s) for validation. if not None, ' - 'each should matches with one in --train-subset-feat-files') - parser.add_argument('--dict', default=None, type=str, - help='path to the dictionary') - parser.add_argument('--non-lang-syms', default=None, type=str, - help='path to a file listing non-linguistic symbols, e.g., ' - 'etc. One entry per line. To be filtered out when calculating WER/CER.') - parser.add_argument('--word-dict', default=None, type=str, - help='path to the word dictionary. Only relevant for decoding') - parser.add_argument('--wer-output-filter', default=None, type=str, - help='path to wer_output_filter file for WER evaluation') - parser.add_argument('--left-pad-source', default='False', type=str, metavar='BOOL', - help='pad the source on the left') - parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL', - help='pad the target on the left') - parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N', - help='max number of frames in the source sequence') - parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N', - help='max number of tokens in the target sequence') - parser.add_argument('--upsample-primary', default=1, type=int, - help='amount to upsample primary dataset') - parser.add_argument('--feat-in-channels', default=1, type=int, metavar='N', - help='feature input channels') + parser.add_argument("data", help="path to data directory") + parser.add_argument("--dict", default=None, type=str, + help="path to the dictionary") + parser.add_argument("--non-lang-syms", default=None, type=str, + help="path to a file listing non-linguistic symbols, e.g., " + "etc. One entry per line. To be filtered out when calculating WER/CER.") + parser.add_argument("--word-dict", default=None, type=str, + help="path to the word dictionary. 
Only relevant for decoding") + parser.add_argument("--wer-output-filter", default=None, type=str, + help="path to wer_output_filter file for WER evaluation") + parser.add_argument("--max-source-positions", default=1024, type=int, metavar="N", + help="max number of frames in the source sequence") + parser.add_argument("--max-target-positions", default=1024, type=int, metavar="N", + help="max number of tokens in the target sequence") + parser.add_argument("--upsample-primary", default=1, type=int, + help="amount to upsample primary dataset") + parser.add_argument("--feat-in-channels", default=1, type=int, metavar="N", + help="feature input channels") # fmt: off @classmethod @@ -125,9 +176,9 @@ def build_dictionary(cls, filenames, workers=1, threshold=-1, nwords=-1, padding """ raise NotImplementedError - def __init__(self, args, dictionary, word_dict=None): + def __init__(self, args, tgt_dict, word_dict=None): super().__init__(args) - self.dictionary = dictionary + self.tgt_dict = tgt_dict self.word_dict = word_dict self.feat_in_channels = args.feat_in_channels torch.backends.cudnn.deterministic = True @@ -143,22 +194,17 @@ def setup_task(cls, args, **kwargs): Args: args (argparse.Namespace): parsed command-line arguments """ - args.left_pad_source = options.eval_bool(args.left_pad_source) - args.left_pad_target = options.eval_bool(args.left_pad_target) - # load dictionaries - dict_path = os.path.join(os.path.dirname(args.train_text_files[0]), 'dict.txt') \ - if args.dict is None and args.train_text_files is not None else args.dict - assert dict_path is not None, 'Please specify --dict' - dictionary = cls.load_dictionary(dict_path, non_lang_syms=args.non_lang_syms) - logger.info('dictionary: {} types'.format(len(dictionary))) + dict_path = os.path.join(args.data, "dict.txt") if args.dict is None else args.dict + tgt_dict = cls.load_dictionary(dict_path, non_lang_syms=args.non_lang_syms) + logger.info("dictionary: {} types".format(len(tgt_dict))) if args.word_dict is not None: word_dict = cls.load_dictionary(args.word_dict) - logger.info('word dictionary: {} types'.format(len(word_dict))) - return cls(args, dictionary, word_dict) + logger.info("word dictionary: {} types".format(len(word_dict))) + return cls(args, tgt_dict, word_dict) else: - return cls(args, dictionary) + return cls(args, tgt_dict) def load_dataset(self, split, epoch=0, combine=False, **kwargs): """Load a given dataset split. 
@@ -166,113 +212,47 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs): Args: split (str): name of the split (e.g., train, valid, test) """ - src_datasets = [] - tgt_datasets = [] - - if split == 'train': - feat_files = self.args.train_feat_files - text_files = self.args.train_text_files - utt2num_frames_files = self.args.train_utt2num_frames_files # can be None - assert len(feat_files) > 0 and len(feat_files) == len(text_files) - assert utt2num_frames_files is None or len(feat_files) == len(utt2num_frames_files) - feat_files = [feat_files[epoch % len(feat_files)]] - text_files = [text_files[epoch % len(text_files)]] - if utt2num_frames_files is not None: - utt2num_frames_files = [utt2num_frames_files[epoch % len(utt2num_frames_files)]] - else: - utt2num_frames_files = [None] - elif split == 'valid': - feat_files = self.args.valid_feat_files - text_files = self.args.valid_text_files - utt2num_frames_files = self.args.valid_utt2num_frames_files # can be None - if utt2num_frames_files is None: - utt2num_frames_files = [None] * len(feat_files) - elif split == 'test': - feat_files = self.args.test_feat_files - text_files = self.args.test_text_files # can be None - utt2num_frames_files = self.args.test_utt2num_frames_files # can be None - if text_files is None: - text_files = [None] * len(feat_files) - if utt2num_frames_files is None: - utt2num_frames_files = [None] * len(feat_files) - elif split == 'train_subset': - feat_files = self.args.train_subset_feat_files - text_files = self.args.train_subset_text_files - utt2num_frames_files = self.args.train_subset_utt2num_frames_files # can be None - if utt2num_frames_files is None: - utt2num_frames_files = [None] * len(feat_files) - else: - raise ValueError('split should be one of "train", "valid", "test", "train_subset"') - - assert len(feat_files) > 0 and len(feat_files) == len(text_files) and \ - len(feat_files) == len(utt2num_frames_files) - file_tuples = zip(feat_files, text_files, utt2num_frames_files) - for feat, text, utt2num_frames in file_tuples: - assert ScpCachedDataset.exists(feat), feat + ' does not exists' - assert text is None or AsrTextDataset.exists(text), text + ' does not exists' - assert utt2num_frames is None or ScpCachedDataset.exists(utt2num_frames), \ - utt2num_frames + ' does not exists' - src_datasets.append(ScpCachedDataset(feat, utt2num_frames, ordered_prefetch=True)) - logger.info('{} {} examples'.format(feat, len(src_datasets[-1]))) - if text is not None: - tgt_datasets.append(AsrTextDataset(text, self.dictionary)) - logger.info('{} {} examples'.format(text, len(tgt_datasets[-1]))) - - if not combine: - break - - if len(tgt_datasets) > 0: - assert len(src_datasets) == len(tgt_datasets) - - self.feat_dim = src_datasets[0].feat_dim - - if len(src_datasets) == 1: - src_dataset = src_datasets[0] - tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None - else: - for i in range(1, len(src_datasets)): - assert self.feat_dim == src_datasets[i].feat_dim, \ - 'feature dimension does not match across multiple scp files' - sample_ratios = [1] * len(src_datasets) - sample_ratios[0] = self.args.upsample_primary - src_dataset = ConcatDataset(src_datasets, sample_ratios) - tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) \ - if len(tgt_datasets) > 0 else None - - self.datasets[split] = SpeechDataset( - src_dataset, src_dataset.sizes, - tgt_dataset, tgt_dataset.sizes if tgt_dataset is not None else None, - self.dictionary, - left_pad_source=self.args.left_pad_source, - 
left_pad_target=self.args.left_pad_target, + paths = self.args.data.split(os.pathsep) + assert len(paths) > 0 + data_path = paths[epoch % len(paths)] + + self.datasets[split] = get_asr_dataset_from_json( + data_path, split, self.tgt_dict, + combine=combine, + upsample_primary=self.args.upsample_primary, max_source_positions=self.args.max_source_positions, max_target_positions=self.args.max_target_positions, ) - # update the counts of and in dictionary with training data - if split == 'train': - self.dictionary.count[self.dictionary.eos()] = len(tgt_dataset) + src_dataset = self.datasets[split].src + self.feat_dim = src_dataset.feat_dim if not isinstance(src_dataset, ConcatDataset) \ + else src_dataset.datasets[0].feat_dim + + # update the counts of and in tgt_dict with training data + if split == "train": + tgt_dataset = self.datasets[split].tgt + self.tgt_dict.count[self.tgt_dict.eos()] = len(tgt_dataset) unk_count = 0 for i in range(len(tgt_dataset)): - unk_count += (tgt_dataset[i][0] == self.dictionary.unk()).int().sum().item() - self.dictionary.count[self.dictionary.unk()] = unk_count + unk_count += (tgt_dataset[i][0] == self.tgt_dict.unk()).int().sum().item() + self.tgt_dict.count[self.tgt_dict.unk()] = unk_count def build_generator(self, args): if args.score_reference: args.score_reference = False logger.warning( - '--score-reference is not applicable to speech recognition, ignoring it.' + "--score-reference is not applicable to speech recognition, ignoring it." ) from fairseq.sequence_generator import SequenceGenerator # Choose search strategy. Defaults to Beam Search. - sampling = getattr(args, 'sampling', False) - sampling_topk = getattr(args, 'sampling_topk', -1) - sampling_topp = getattr(args, 'sampling_topp', -1.0) - diverse_beam_groups = getattr(args, 'diverse_beam_groups', -1) - diverse_beam_strength = getattr(args, 'diverse_beam_strength', 0.5), - match_source_len = getattr(args, 'match_source_len', False) - diversity_rate = getattr(args, 'diversity_rate', -1) + sampling = getattr(args, "sampling", False) + sampling_topk = getattr(args, "sampling_topk", -1) + sampling_topp = getattr(args, "sampling_topp", -1.0) + diverse_beam_groups = getattr(args, "diverse_beam_groups", -1) + diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5), + match_source_len = getattr(args, "match_source_len", False) + diversity_rate = getattr(args, "diversity_rate", -1) if ( sum( int(cond) @@ -285,9 +265,9 @@ def build_generator(self, args): ) > 1 ): - raise ValueError('Provided Search parameters are mutually exclusive.') - assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling' - assert sampling_topp < 0 or sampling, '--sampling-topp requires --sampling' + raise ValueError("Provided Search parameters are mutually exclusive.") + assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling" + assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling" if sampling: search_strategy = search.Sampling(self.target_dictionary, sampling_topk, sampling_topp) @@ -308,18 +288,18 @@ def build_generator(self, args): return SequenceGenerator( self.target_dictionary, - beam_size=getattr(args, 'beam', 5), - max_len_a=getattr(args, 'max_len_a', 0), - max_len_b=getattr(args, 'max_len_b', 200), - min_len=getattr(args, 'min_len', 1), - normalize_scores=(not getattr(args, 'unnormalized', False)), - len_penalty=getattr(args, 'lenpen', 1), - unk_penalty=getattr(args, 'unkpen', 0), - temperature=getattr(args, 'temperature', 1.), - 
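Since the positional data argument may name several directories joined by os.pathsep, load_dataset picks one shard per epoch, preserving the round-robin behaviour that multiple --train-feat-files used to provide. A small sketch of the epoch-to-shard mapping (shard names are illustrative):

import os

data_arg = os.pathsep.join(["data_shard0", "data_shard1"])  # "data_shard0:data_shard1" on Linux
paths = data_arg.split(os.pathsep)
for epoch in range(4):
    # epoch 0 -> data_shard0, epoch 1 -> data_shard1, epoch 2 -> data_shard0, ...
    print(epoch, paths[epoch % len(paths)])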
match_source_len=getattr(args, 'match_source_len', False), - no_repeat_ngram_size=getattr(args, 'no_repeat_ngram_size', 0), + beam_size=getattr(args, "beam", 5), + max_len_a=getattr(args, "max_len_a", 0), + max_len_b=getattr(args, "max_len_b", 200), + min_len=getattr(args, "min_len", 1), + normalize_scores=(not getattr(args, "unnormalized", False)), + len_penalty=getattr(args, "lenpen", 1), + unk_penalty=getattr(args, "unkpen", 0), + temperature=getattr(args, "temperature", 1.), + match_source_len=getattr(args, "match_source_len", False), + no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0), search_strategy=search_strategy, - eos_factor=getattr(args, 'eos_factor', None), + eos_factor=getattr(args, "eos_factor", None), ) def build_dataset_for_inference(self, src_tokens, src_lengths): @@ -334,8 +314,8 @@ def build_model(self, args): def valid_step(self, sample, model, criterion): loss, sample_size, logging_output = super().valid_step(sample, model, criterion) ( - logging_output['word_error'], logging_output['word_count'], - logging_output['char_error'], logging_output['char_count'], + logging_output["word_error"], logging_output["word_count"], + logging_output["char_error"], logging_output["char_count"], ) = self._inference_with_wer(self.decoder_for_validation, sample, model) return loss, sample_size, logging_output @@ -347,14 +327,14 @@ def inference_step(self, generator, models, sample, prefix_tokens=None, lm_weigh def reduce_metrics(self, logging_outputs, criterion): super().reduce_metrics(logging_outputs, criterion) - word_error = utils.item(sum(log.get('word_error', 0) for log in logging_outputs)) - word_count = utils.item(sum(log.get('word_count', 0) for log in logging_outputs)) - char_error = utils.item(sum(log.get('char_error', 0) for log in logging_outputs)) - char_count = utils.item(sum(log.get('char_count', 0) for log in logging_outputs)) + word_error = utils.item(sum(log.get("word_error", 0) for log in logging_outputs)) + word_count = utils.item(sum(log.get("word_count", 0) for log in logging_outputs)) + char_error = utils.item(sum(log.get("char_error", 0) for log in logging_outputs)) + char_count = utils.item(sum(log.get("char_count", 0) for log in logging_outputs)) if word_count > 0: - metrics.log_scalar('wer', float(word_error) / word_count * 100, word_count, round=4) + metrics.log_scalar("wer", float(word_error) / word_count * 100, word_count, round=4) if char_count > 0: - metrics.log_scalar('cer', float(char_error) / char_count * 100, char_count, round=4) + metrics.log_scalar("cer", float(char_error) / char_count * 100, char_count, round=4) def max_positions(self): """Return the max sentence length allowed by the task.""" @@ -363,7 +343,7 @@ def max_positions(self): @property def target_dictionary(self): """Return the target :class:`~fairseq.data.AsrDictionary`.""" - return self.dictionary + return self.tgt_dict @property def word_dictionary(self): @@ -376,13 +356,13 @@ def _inference_with_wer(self, decoder, sample, model): scorer = wer.Scorer(self.target_dictionary, wer_output_filter=self.args.wer_output_filter) tokens, lprobs, _ = decoder.decode([model], sample) pred = tokens[:, 1:].data.cpu() # bsz x len - target = sample['target'] + target = sample["target"] assert pred.size(0) == target.size(0) # compute word error stats scorer.reset() for i in range(target.size(0)): - utt_id = sample['utt_id'][i] - ref_tokens = sample['target_raw_text'][i] + utt_id = sample["utt_id"][i] + ref_tokens = sample["target_raw_text"][i] pred_tokens = 
self.target_dictionary.string(pred.data[i]) scorer.add_evaluation( utt_id, ref_tokens, pred_tokens, bpe_symbol=self.args.remove_bpe, diff --git a/espresso/tools/asr_prep_json.py b/espresso/tools/asr_prep_json.py new file mode 100755 index 000000000..f39755bfa --- /dev/null +++ b/espresso/tools/asr_prep_json.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +from collections import OrderedDict +import json +import logging +import sys + + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO, + stream=sys.stdout, +) +logger = logging.getLogger("espresso.tools.asr_prep_json") + + +def read_file(ordered_dict, key, dtype, *paths): + for path in paths: + with open(path, "r", encoding="utf-8") as f: + for line in f: + utt_id, val = line.strip().split(None, 1) + if utt_id in ordered_dict: + assert key not in ordered_dict[utt_id], \ + "Duplicate utterance id " + utt_id + " in " + key + ordered_dict[utt_id].update({key: dtype(val)}) + else: + ordered_dict[utt_id] = {key: val} + return ordered_dict + + +def main(): + parser = argparse.ArgumentParser( + description="Wrap all related files of a dataset into a single json file" + ) + # fmt: off + parser.add_argument("--feat-files", nargs="+", required=True, + help="path(s) to scp feature file(s)") + parser.add_argument("--token-text-files", nargs="+", default=None, + help="path(s) to token_text file(s)") + parser.add_argument("--utt2num-frames-files", nargs="+", default=None, + help="path(s) to utt2num_frames file(s)") + parser.add_argument("--output", required=True, type=argparse.FileType("w"), + help="path to save json output") + args = parser.parse_args() + # fmt: on + + obj = OrderedDict() + obj = read_file(obj, "feat", str, *(args.feat_files)) + if args.token_text_files is not None: + obj = read_file(obj, "token_text", str, *(args.token_text_files)) + if args.utt2num_frames_files is not None: + obj = read_file(obj, "utt2num_frames", int, *(args.utt2num_frames_files)) + + json.dump(obj, args.output, indent=4) + logger.info("Dumped {} examples in {}".format(len(obj), args.output.name)) + + +if __name__ == "__main__": + main() diff --git a/espresso/tools/wer.py b/espresso/tools/wer.py index c5190d247..353a80cb3 100644 --- a/espresso/tools/wer.py +++ b/espresso/tools/wer.py @@ -138,6 +138,9 @@ def tot_char_count(self): return self.char_counter['words'] def add_ordered_utt_list(self, *args): + if len(args) == 1 and isinstance(args[0], list): # aleady a list of utterance ids + self.ordered_utt_list = args[0] + return self.ordered_utt_list = [] for text_file in args: with open(text_file, 'r', encoding='utf-8') as f: diff --git a/examples/asr_librispeech/run.sh b/examples/asr_librispeech/run.sh index a09335a0b..a9ac7ef18 100755 --- a/examples/asr_librispeech/run.sh +++ b/examples/asr_librispeech/run.sh @@ -195,19 +195,31 @@ if [ ${stage} -le 6 ]; then done fi -train_feat=$train_feat_dir/feats.scp -train_token_text=data/$train_set/token_text -train_utt2num_frames=data/$train_set/utt2num_frames -valid_feat=$valid_feat_dir/feats.scp -valid_token_text=data/$valid_set/token_text -valid_utt2num_frames=data/$valid_set/utt2num_frames if [ ${stage} -le 7 ]; then - echo "Stage 7: Model Training" + echo "Stage 7: Dump Json Files" + train_feat=$train_feat_dir/feats.scp + train_token_text=data/$train_set/token_text + 
train_utt2num_frames=data/$train_set/utt2num_frames + valid_feat=$valid_feat_dir/feats.scp + valid_token_text=data/$valid_set/token_text + valid_utt2num_frames=data/$valid_set/utt2num_frames + asr_prep_json.py --feat-files $train_feat --token-text-files $train_token_text --utt2num-frames-files $train_utt2num_frames --output data/train.json + asr_prep_json.py --feat-files $valid_feat --token-text-files $valid_token_text --utt2num-frames-files $valid_utt2num_frames --output data/valid.json + for dataset in $test_set; do + feat=${dumpdir}/$dataset/delta${do_delta}/feats.scp + token_text=data/$dataset/token_text + utt2num_frames=data/$dataset/utt2num_frames + asr_prep_json.py --feat-files $feat --token-text-files $token_text --utt2num-frames-files $utt2num_frames --output data/$dataset.json + done +fi + +if [ ${stage} -le 8 ]; then + echo "Stage 8: Model Training" valid_subset=valid mkdir -p $dir/logs log_file=$dir/logs/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" - CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --task speech_recognition_espresso --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data --task speech_recognition_espresso --seed 1 --user-dir espresso \ --log-interval 4000 --log-format simple --print-training-sample-interval 2000 \ --num-workers 0 --max-tokens 26000 --max-sentences 24 --curriculum 1 \ --valid-subset $valid_subset --max-sentences-valid 48 --ddp-backend no_c10d \ @@ -219,14 +231,12 @@ if [ ${stage} -le 7 ]; then --arch speech_conv_lstm_librispeech --criterion label_smoothed_cross_entropy_v2 \ --label-smoothing 0.1 --smoothing-type uniform \ --scheduled-sampling-probs 1.0 --start-scheduled-sampling-epoch 1 \ - --train-feat-files $train_feat --train-text-files $train_token_text --train-utt2num-frames-files $train_utt2num_frames \ - --valid-feat-files $valid_feat --valid-text-files $valid_token_text --valid-utt2num-frames-files $valid_utt2num_frames \ --dict $dict --remove-bpe sentencepiece \ --max-source-positions 9999 --max-target-positions 999 2>&1 | tee $log_file fi -if [ ${stage} -le 8 ]; then - echo "Stage 8: Decoding" +if [ ${stage} -le 9 ]; then + echo "Stage 9: Decoding" opts="" path=$dir/$checkpoint decode_affix= @@ -236,14 +246,10 @@ if [ ${stage} -le 8 ]; then decode_affix=shallow_fusion fi for dataset in $test_set; do - feat=${dumpdir}/$dataset/delta${do_delta}/feats.scp - text=data/$dataset/token_text - utt2num_frames=data/$dataset/utt2num_frames decode_dir=$dir/decode_$dataset${decode_affix:+_${decode_affix}} - CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py \ + CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py data \ --task speech_recognition_espresso --user-dir espresso --max-tokens 15000 --max-sentences 24 \ - --num-shards 1 --shard-id 0 --test-feat-files $feat --test-text-files $text --test-utt2num-frames-files $utt2num_frames \ - --dict $dict --remove-bpe sentencepiece \ + --num-shards 1 --shard-id 0 --dict $dict --remove-bpe sentencepiece --gen-subset $dataset \ --max-source-positions 9999 --max-target-positions 999 \ --path $path --beam 60 --max-len-a 0.08 --max-len-b 0 --lenpen 1.0 \ --results-path $decode_dir $opts diff --git a/examples/asr_swbd/run.sh b/examples/asr_swbd/run.sh index db3bc8ee5..409a4eea1 100755 --- a/examples/asr_swbd/run.sh +++ b/examples/asr_swbd/run.sh @@ -233,21 +233,34 @@ if [ $stage -le 5 ]; then done fi -train_feat=$train_feat_dir/feats.scp 
-train_token_text=data/$train_set/token_text -train_utt2num_frames=data/$train_set/utt2num_frames -valid_feat=$valid_feat_dir/feats.scp -valid_token_text=data/$valid_set/token_text -valid_utt2num_frames=data/$valid_set/utt2num_frames if [ $stage -le 6 ]; then - echo "Stage 6: Model Training" + echo "Stage 6: Dump Json Files" + train_feat=$train_feat_dir/feats.scp + train_token_text=data/$train_set/token_text + train_utt2num_frames=data/$train_set/utt2num_frames + valid_feat=$valid_feat_dir/feats.scp + valid_token_text=data/$valid_set/token_text + valid_utt2num_frames=data/$valid_set/utt2num_frames + asr_prep_json.py --feat-files $train_feat --token-text-files $train_token_text --utt2num-frames-files $train_utt2num_frames --output data/train.json + asr_prep_json.py --feat-files $valid_feat --token-text-files $valid_token_text --utt2num-frames-files $valid_utt2num_frames --output data/valid.json + for dataset in $test_set; do + feat=${dumpdir}/$dataset/delta${do_delta}/feats.scp + utt2num_frames=data/$dataset/utt2num_frames + # only score train_dev with built-in scorer + text_opt= && [ "$dataset" == "train_dev" ] && text_opt="--token-text-files data/$dataset/token_text" + asr_prep_json.py --feat-files $feat $text_opt --utt2num-frames-files $utt2num_frames --output data/$dataset.json + done +fi + +if [ $stage -le 7 ]; then + echo "Stage 7: Model Training" valid_subset=valid opts="" [ -f local/wer_output_filter ] && opts="$opts --wer-output-filter local/wer_output_filter" mkdir -p $dir/logs log_file=$dir/logs/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" - CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --task speech_recognition_espresso --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data --task speech_recognition_espresso --seed 1 --user-dir espresso \ --log-interval 1500 --log-format simple --print-training-sample-interval 2000 \ --num-workers 0 --max-tokens 26000 --max-sentences 48 --curriculum 2 \ --valid-subset $valid_subset --max-sentences-valid 64 --ddp-backend no_c10d \ @@ -259,14 +272,12 @@ if [ $stage -le 6 ]; then --arch speech_conv_lstm_swbd --criterion label_smoothed_cross_entropy_v2 \ --label-smoothing 0.1 --smoothing-type uniform \ --scheduled-sampling-probs 0.9,0.8,0.7,0.6 --start-scheduled-sampling-epoch 6 \ - --train-feat-files $train_feat --train-text-files $train_token_text --train-utt2num-frames-files $train_utt2num_frames \ - --valid-feat-files $valid_feat --valid-text-files $valid_token_text --valid-utt2num-frames-files $valid_utt2num_frames \ --dict $dict --remove-bpe sentencepiece --non-lang-syms $nlsyms \ --max-source-positions 9999 --max-target-positions 999 $opts 2>&1 | tee $log_file fi -if [ $stage -le 7 ]; then - echo "Stage 7: Decoding" +if [ $stage -le 8 ]; then + echo "Stage 8: Decoding" [ ! 
-d $KALDI_ROOT ] && echo "Expected $KALDI_ROOT to exist" && exit 1; opts="" path=$dir/$checkpoint @@ -278,15 +289,9 @@ if [ $stage -le 7 ]; then fi [ -f local/wer_output_filter ] && opts="$opts --wer-output-filter local/wer_output_filter" for dataset in $test_set; do - feat=${dumpdir}/$dataset/delta${do_delta}/feats.scp - utt2num_frames=data/$dataset/utt2num_frames - decode_dir=$dir/decode_${dataset}${decode_affix:+_${decode_affix}} - # only score train_dev with built-in scorer - text_opt= && [ "$dataset" == "train_dev" ] && text_opt="--test-text-files data/$dataset/token_text" - CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py \ + CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py data \ --task speech_recognition_espresso --user-dir espresso --max-tokens 24000 --max-sentences 48 \ - --num-shards 1 --shard-id 0 --test-feat-files $feat $text_opt --test-utt2num-frames-files $utt2num_frames \ - --dict $dict --remove-bpe sentencepiece --non-lang-syms $nlsyms \ + --num-shards 1 --shard-id 0 --dict $dict --remove-bpe sentencepiece --non-lang-syms $nlsyms --gen-subset $dataset \ --max-source-positions 9999 --max-target-positions 999 \ --path $path --beam 35 --max-len-a 0.1 --max-len-b 0 --lenpen 1.0 \ --results-path $decode_dir $opts diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index 20b2923ff..dd21c1058 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -247,27 +247,44 @@ if [ ${stage} -le 7 ] && $use_wordlm; then done fi -train_feat=$train_feat_dir/feats.scp -train_token_text=data/$train_set/token_text -train_utt2num_frames=data/$train_set/utt2num_frames -valid_feat=$valid_feat_dir/feats.scp -valid_token_text=data/$valid_set/token_text -valid_utt2num_frames=data/$valid_set/utt2num_frames if [ ${stage} -le 8 ]; then - echo "Stage 8: Model Training" + echo "Stage 8: Dump Json Files" + train_feat=$train_feat_dir/feats.scp + train_token_text=data/$train_set/token_text + train_utt2num_frames=data/$train_set/utt2num_frames + valid_feat=$valid_feat_dir/feats.scp + valid_token_text=data/$valid_set/token_text + valid_utt2num_frames=data/$valid_set/utt2num_frames + train_subset_feat=$train_subset_feat_dir/feats.scp + train_subset_token_text=data/${train_set}_${train_subset_size}/token_text + train_subset_utt2num_frames=data/${train_set}_${train_subset_size}/utt2num_frames + asr_prep_json.py --feat-files $train_feat --token-text-files $train_token_text --utt2num-frames-files $train_utt2num_frames --output data/train.json + asr_prep_json.py --feat-files $valid_feat --token-text-files $valid_token_text --utt2num-frames-files $valid_utt2num_frames --output data/valid.json + asr_prep_json.py --feat-files $train_subset_feat --token-text-files $train_subset_token_text --utt2num-frames-files $train_subset_utt2num_frames --output data/train_subset.json + for dataset in $valid_set $test_set; do + if [ "$dataset" == "$valid_set" ]; then + feat=$valid_feat_dir/feats.scp + elif [ "$dataset" == "$test_set" ]; then + feat=$test_feat_dir/feats.scp + fi + token_text=data/$dataset/token_text + utt2num_frames=data/$dataset/utt2num_frames + asr_prep_json.py --feat-files $feat --token-text-files $token_text --utt2num-frames-files $utt2num_frames --output data/$dataset.json + done +fi + +if [ ${stage} -le 9 ]; then + echo "Stage 9: Model Training" opts="" valid_subset=valid if $validate_on_train_subset; then valid_subset="$valid_subset,train_subset" - opts="$opts --train-subset-feat-files 
$train_subset_feat_dir/feats.scp" - opts="$opts --train-subset-text-files data/${train_set}_${train_subset_size}/token_text" - opts="$opts --train-subset-utt2num-frames-files data/${train_set}_${train_subset_size}/utt2num_frames" fi [ -f local/wer_output_filter ] && opts="$opts --wer-output-filter local/wer_output_filter" mkdir -p $dir/logs log_file=$dir/logs/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" - CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py --task speech_recognition_espresso --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data --task speech_recognition_espresso --seed 1 --user-dir espresso \ --log-interval 400 --log-format simple --print-training-sample-interval 1000 \ --num-workers 0 --max-tokens 24000 --max-sentences 32 --curriculum 2 \ --valid-subset $valid_subset --max-sentences-valid 64 --ddp-backend no_c10d \ @@ -279,14 +296,12 @@ if [ ${stage} -le 8 ]; then --arch speech_conv_lstm_wsj --criterion label_smoothed_cross_entropy_v2 \ --label-smoothing 0.05 --smoothing-type temporal \ --scheduled-sampling-probs 0.5 --start-scheduled-sampling-epoch 6 \ - --train-feat-files $train_feat --train-text-files $train_token_text --train-utt2num-frames-files $train_utt2num_frames \ - --valid-feat-files $valid_feat --valid-text-files $valid_token_text --valid-utt2num-frames-files $valid_utt2num_frames \ --dict $dict --non-lang-syms $nlsyms \ --max-source-positions 9999 --max-target-positions 999 $opts 2>&1 | tee $log_file fi -if [ ${stage} -le 9 ]; then - echo "Stage 9: Decoding" +if [ ${stage} -le 10 ]; then + echo "Stage 10: Decoding" opts="" path=$dir/$checkpoint decode_affix= @@ -303,18 +318,10 @@ if [ ${stage} -le 9 ]; then fi [ -f local/wer_output_filter ] && opts="$opts --wer-output-filter local/wer_output_filter" for dataset in $valid_set $test_set; do - if [ "$dataset" == "$valid_set" ]; then - feat=$valid_feat_dir/feats.scp - elif [ "$dataset" == "$test_set" ]; then - feat=$test_feat_dir/feats.scp - fi - text=data/$dataset/token_text - utt2num_frames=data/$dataset/utt2num_frames decode_dir=$dir/decode_$dataset${decode_affix:+_${decode_affix}} - CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py \ + CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py data \ --task speech_recognition_espresso --user-dir espresso --max-tokens 20000 --max-sentences 32 \ - --num-shards 1 --shard-id 0 --test-feat-files $feat --test-text-files $text --test-utt2num-frames-files $utt2num_frames \ - --dict $dict --non-lang-syms $nlsyms \ + --num-shards 1 --shard-id 0 --dict $dict --non-lang-syms $nlsyms --gen-subset $dataset \ --max-source-positions 9999 --max-target-positions 999 \ --path $path --beam 50 --max-len-a 0.2 --max-len-b 0 --lenpen 1.0 \ --results-path $decode_dir $opts --print-alignment diff --git a/tests/espresso/test_speech_dataset.py b/tests/espresso/test_speech_dataset.py index 79818cc8c..cae61ce50 100644 --- a/tests/espresso/test_speech_dataset.py +++ b/tests/espresso/test_speech_dataset.py @@ -41,47 +41,46 @@ def make_dictionary(): @staticmethod def generate_feats(test_dir, num=10, seed=0): """generate feature matrices.""" - feats = {} + expected_feats = {} np.random.seed(seed) - with open( - os.path.join(test_dir, 'feats.scp'), 'w', encoding='utf-8', - ) as f: - for i in range(num): - utt_id = 'utt_id_' + str(i) - ark_file = os.path.join(test_dir, 'mat_' + str(i) + '.ark') - f.write(utt_id + ' ' + ark_file + ':0\n') - length = 
np.random.randint(200, 800) - m = np.random.uniform(-10.0, 10.0, (length, 40)) - feats[utt_id] = m - kaldi_io.write_mat(ark_file, m) - return feats + utt_ids, rxfiles, utt2num_frames = [], [], [] + for i in range(num): + utt_id = 'utt_id_' + str(i) + ark_file = os.path.join(test_dir, 'mat_' + str(i) + '.ark') + length = np.random.randint(200, 800) + m = np.random.uniform(-10.0, 10.0, (length, 40)) + expected_feats[utt_id] = m + kaldi_io.write_mat(ark_file, m) + utt_ids.append(utt_id) + rxfiles.append(ark_file + ':0') + utt2num_frames.append(length) + return expected_feats, utt_ids, rxfiles, utt2num_frames @staticmethod - def generate_text_tokens(test_dir, num=10, seed=0): + def generate_text(test_dir, num=10, seed=0): """generate token text, where utterances are in a (random) different order from those in feats.scp.""" - text_tokens = {} + expected_text = {} alphabet = string.ascii_lowercase space = '' vocab = list(alphabet) vocab.append(space) np.random.seed(seed) - with open( - os.path.join(test_dir, 'text_tokens'), 'w', encoding='utf-8', - ) as f: - for i in np.random.permutation(range(num)): - utt_id = 'utt_id_' + str(i) - length = np.random.randint(10, 100) - tokens = [ - vocab[np.random.randint(0, len(vocab))] for _ in range(length) - ] - if tokens[0] == space: - tokens[0] = vocab[np.random.randint(0, len(vocab) - 1)] - if tokens[-1] == space: - tokens[-1] = vocab[np.random.randint(0, len(vocab) - 1)] - text_tokens[utt_id] = tokens - f.write(utt_id + ' ' + ' '.join(tokens) + '\n') - return text_tokens + utt_ids, token_text = [], [] + for i in np.random.permutation(range(num)): + utt_id = 'utt_id_' + str(i) + length = np.random.randint(10, 100) + tokens = [ + vocab[np.random.randint(0, len(vocab))] for _ in range(length) + ] + if tokens[0] == space: + tokens[0] = vocab[np.random.randint(0, len(vocab) - 1)] + if tokens[-1] == space: + tokens[-1] = vocab[np.random.randint(0, len(vocab) - 1)] + expected_text[utt_id] = tokens + utt_ids.append(utt_id) + token_text.append(' '.join(tokens)) + return expected_text, utt_ids, token_text def setUp(self): self.test_dir = './temp' @@ -91,30 +90,35 @@ def setUp(self): self.batch_size = 8 self.cache_size = 16 self.dictionary = self.make_dictionary() - self.expected_feats = self.generate_feats( - self.test_dir, num=self.num_audios, seed=0, - ) - self.expected_tokens = self.generate_text_tokens( - self.test_dir, num=self.num_transripts, seed=1, - ) + ( + self.expected_feats, self.feats_utt_ids, self.rxfiles, self.utt2num_frames + ) = self.generate_feats(self.test_dir, num=self.num_audios, seed=0) + ( + self.expected_text, self.text_utt_ids, self.token_text + ) = self.generate_text(self.test_dir, num=self.num_transripts, seed=1) self.cuda = torch.cuda.is_available() def _speech_dataset_helper( - self, all_in_memory=False, ordered_prefetch=False, + self, all_in_memory=False, ordered_prefetch=False, has_utt2num_frames=False, ): if not all_in_memory: src_dataset = ScpCachedDataset( - path=os.path.join(self.test_dir, 'feats.scp'), + utt_ids=self.feats_utt_ids, + rxfiles=self.rxfiles, + utt2num_frames=self.utt2num_frames if has_utt2num_frames else None, ordered_prefetch=ordered_prefetch, cache_size=self.cache_size, ) else: src_dataset = ScpInMemoryDataset( - path=os.path.join(self.test_dir, 'feats.scp') + utt_ids=self.feats_utt_ids, + rxfiles=self.rxfiles, + utt2num_frames=self.utt2num_frames if has_utt2num_frames else None, ) tgt_dataset = AsrTextDataset( - path=os.path.join(self.test_dir, 'text_tokens'), + utt_ids=self.text_utt_ids, + 
token_text=self.token_text, dictionary=self.dictionary, ) @@ -162,7 +166,7 @@ def _speech_dataset_helper( src_frames[j, :src_lengths[j], :] ) self.assertEqual( - self.expected_tokens[utt_id], + self.expected_text[utt_id], tgt_tokens[j], ) @@ -175,6 +179,9 @@ def test_speech_dataset_cached_with_ordered_prefetch(self): def test_speech_dataset_all_in_memory(self): self._speech_dataset_helper(all_in_memory=True) + def test_speech_dataset_has_utt2num_frames(self): + self._speech_dataset_helper(has_utt2num_frames=True) + def assertTensorEqual(self, t1, t2): self.assertEqual(t1.size(), t2.size(), "size mismatch") if ( From ea7732b16951ccf36ff0d303fa2cf9501b6a5792 Mon Sep 17 00:00:00 2001 From: freewym Date: Sat, 22 Feb 2020 15:16:44 -0500 Subject: [PATCH 069/119] code adaptation/changes according to the commits on Feb 21, 2020 --- espresso/data/scp_text_dataset.py | 2 +- espresso/tasks/language_modeling_for_asr.py | 4 ++-- espresso/tasks/speech_recognition.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/espresso/data/scp_text_dataset.py b/espresso/data/scp_text_dataset.py index c81ddea25..d39d23914 100644 --- a/espresso/data/scp_text_dataset.py +++ b/espresso/data/scp_text_dataset.py @@ -97,7 +97,7 @@ def __init__( self.cache = None self.cache_index = {} self.cache_size = cache_size # in terms of number of examples - self.start_search_for_next_pos_start = 0 + self.start_pos_for_next_cache = 0 self.ordered_indices = list(range(self.size)) # set to True ONLY if examples are queried in the same order as # self.ordered_indices, and doing this will speed up search of the diff --git a/espresso/tasks/language_modeling_for_asr.py b/espresso/tasks/language_modeling_for_asr.py index 395f18158..423430ad9 100644 --- a/espresso/tasks/language_modeling_for_asr.py +++ b/espresso/tasks/language_modeling_for_asr.py @@ -8,7 +8,7 @@ import torch -from fairseq import tokenizer +from fairseq import tokenizer, utils from fairseq.data import TruncatedDictionary from fairseq.tasks import register_task from fairseq.tasks.language_modeling import LanguageModelingTask @@ -105,7 +105,7 @@ def setup_task(cls, args, **kwargs): dictionary = None output_dictionary = None if args.data: - paths = args.data.split(os.pathsep) + paths = utils.split_paths(args.data) assert len(paths) > 0 dict_path = os.path.join(paths[0], "dict.txt") if args.dict is None \ else args.dict diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index c6d5e8f3b..00fb017d5 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -212,7 +212,7 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs): Args: split (str): name of the split (e.g., train, valid, test) """ - paths = self.args.data.split(os.pathsep) + paths = utils.split_paths(self.args.data) assert len(paths) > 0 data_path = paths[epoch % len(paths)] From 7170edcf3c40523adc62bdeab1a808d0749576cb Mon Sep 17 00:00:00 2001 From: freewym Date: Mon, 24 Feb 2020 21:38:18 -0500 Subject: [PATCH 070/119] move duplicated network parsers to espresso/speech_tools/utils.py; remove remaining coverage option passed in wsj recipe --- espresso/models/speech_fconv.py | 23 +++-------------------- espresso/models/speech_lstm.py | 27 ++++++--------------------- espresso/models/speech_transformer.py | 23 +++-------------------- espresso/tasks/speech_recognition.py | 12 +++++++----- espresso/tools/utils.py | 18 ++++++++++++++++++ examples/asr_wsj/run.sh | 2 +- 6 files changed, 38 insertions(+), 67 deletions(-) 
diff --git a/espresso/models/speech_fconv.py b/espresso/models/speech_fconv.py index 3f119317f..287e8ca92 100644 --- a/espresso/models/speech_fconv.py +++ b/espresso/models/speech_fconv.py @@ -81,26 +81,9 @@ def build_model(cls, args, task): decoder_embed_dict = utils.parse_embedding(args.decoder_embed_path) utils.print_embed_overlap(decoder_embed_dict, task.target_dictionary) - def eval_str_nested_list_or_tuple(x, type=int): - if x is None: - return None - if isinstance(x, str): - x = eval(x) - if isinstance(x, list): - return list( - map(lambda s: eval_str_nested_list_or_tuple(s, type), x)) - elif isinstance(x, tuple): - return tuple( - map(lambda s: eval_str_nested_list_or_tuple(s, type), x)) - else: - try: - return type(x) - except TypeError: - raise TypeError - - out_channels = eval_str_nested_list_or_tuple(args.encoder_conv_channels, type=int) - kernel_sizes = eval_str_nested_list_or_tuple(args.encoder_conv_kernel_sizes, type=int) - strides = eval_str_nested_list_or_tuple(args.encoder_conv_strides, type=int) + out_channels = speech_utils.eval_str_nested_list_or_tuple(args.encoder_conv_channels, type=int) + kernel_sizes = speech_utils.eval_str_nested_list_or_tuple(args.encoder_conv_kernel_sizes, type=int) + strides = speech_utils.eval_str_nested_list_or_tuple(args.encoder_conv_strides, type=int) logger.info('input feature dimension: {}, channels: {}'.format(task.feat_dim, task.feat_in_channels)) assert task.feat_dim % task.feat_in_channels == 0 conv_layers = ConvBNReLU( diff --git a/espresso/models/speech_lstm.py b/espresso/models/speech_lstm.py index c2c828fcc..0dddb02dc 100644 --- a/espresso/models/speech_lstm.py +++ b/espresso/models/speech_lstm.py @@ -161,26 +161,9 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): if args.decoder_freeze_embed: pretrained_decoder_embed.weight.requires_grad = False - def eval_str_nested_list_or_tuple(x, type=int): - if x is None: - return None - if isinstance(x, str): - x = eval(x) - if isinstance(x, list): - return list( - map(lambda s: eval_str_nested_list_or_tuple(s, type), x)) - elif isinstance(x, tuple): - return tuple( - map(lambda s: eval_str_nested_list_or_tuple(s, type), x)) - else: - try: - return type(x) - except TypeError: - raise TypeError - - out_channels = eval_str_nested_list_or_tuple(args.encoder_conv_channels, type=int) - kernel_sizes = eval_str_nested_list_or_tuple(args.encoder_conv_kernel_sizes, type=int) - strides = eval_str_nested_list_or_tuple(args.encoder_conv_strides, type=int) + out_channels = speech_utils.eval_str_nested_list_or_tuple(args.encoder_conv_channels, type=int) + kernel_sizes = speech_utils.eval_str_nested_list_or_tuple(args.encoder_conv_kernel_sizes, type=int) + strides = speech_utils.eval_str_nested_list_or_tuple(args.encoder_conv_strides, type=int) logger.info('input feature dimension: {}, channels: {}'.format(task.feat_dim, task.feat_in_channels)) assert task.feat_dim % task.feat_in_channels == 0 conv_layers = ConvBNReLU( @@ -198,6 +181,8 @@ def eval_str_nested_list_or_tuple(x, type=int): s = stride rnn_encoder_input_size = (rnn_encoder_input_size + s - 1) // s rnn_encoder_input_size *= out_channels[-1] + else: + rnn_encoder_input_size = task.feat_dim scheduled_sampling_rate_scheduler = ScheduledSamplingRateScheduler( args.scheduled_sampling_probs, args.start_scheduled_sampling_epoch, @@ -320,7 +305,7 @@ class SpeechLSTMEncoder(FairseqEncoder): def __init__( self, conv_layers_before=None, input_size=83, hidden_size=512, num_layers=1, dropout_in=0.1, dropout_out=0.1, 
bidirectional=False, - residual=False, left_pad=False, pretrained_embed=None, padding_value=0., + residual=False, left_pad=False, padding_value=0., max_source_positions=DEFAULT_MAX_SOURCE_POSITIONS, ): super().__init__(None) # no src dictionary diff --git a/espresso/models/speech_transformer.py b/espresso/models/speech_transformer.py index 75419585d..55eb883eb 100644 --- a/espresso/models/speech_transformer.py +++ b/espresso/models/speech_transformer.py @@ -98,26 +98,9 @@ def build_embedding(dictionary, embed_dim, path=None): dict, args.decoder_embed_dim, args.decoder_embed_path ) - def eval_str_nested_list_or_tuple(x, type=int): - if x is None: - return None - if isinstance(x, str): - x = eval(x) - if isinstance(x, list): - return list( - map(lambda s: eval_str_nested_list_or_tuple(s, type), x)) - elif isinstance(x, tuple): - return tuple( - map(lambda s: eval_str_nested_list_or_tuple(s, type), x)) - else: - try: - return type(x) - except TypeError: - raise TypeError - - out_channels = eval_str_nested_list_or_tuple(args.encoder_conv_channels, type=int) - kernel_sizes = eval_str_nested_list_or_tuple(args.encoder_conv_kernel_sizes, type=int) - strides = eval_str_nested_list_or_tuple(args.encoder_conv_strides, type=int) + out_channels = speech_utils.eval_str_nested_list_or_tuple(args.encoder_conv_channels, type=int) + kernel_sizes = speech_utils.eval_str_nested_list_or_tuple(args.encoder_conv_kernel_sizes, type=int) + strides = speech_utils.eval_str_nested_list_or_tuple(args.encoder_conv_strides, type=int) logger.info('input feature dimension: {}, channels: {}'.format(task.feat_dim, task.feat_in_channels)) assert task.feat_dim % task.feat_in_channels == 0 conv_layers = ConvBNReLU( diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index 00fb017d5..40483d34f 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -84,8 +84,7 @@ def get_asr_dataset_from_json( if not combine: break - if len(tgt_datasets) > 0: - assert len(src_datasets) == len(tgt_datasets) + assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0 feat_dim = src_datasets[0].feat_dim @@ -99,12 +98,15 @@ def get_asr_dataset_from_json( sample_ratios = [1] * len(src_datasets) sample_ratios[0] = upsample_primary src_dataset = ConcatDataset(src_datasets, sample_ratios) - tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) \ - if len(tgt_datasets) > 0 else None + if len(tgt_datasets) > 0: + tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) + else: + tgt_dataset = None + tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None return SpeechDataset( src_dataset, src_dataset.sizes, - tgt_dataset, tgt_dataset.sizes if tgt_dataset is not None else None, + tgt_dataset, tgt_dataset_sizes, tgt_dict, left_pad_source=False, left_pad_target=False, diff --git a/espresso/tools/utils.py b/espresso/tools/utils.py index bf2cfbeaa..3de35d7f8 100644 --- a/espresso/tools/utils.py +++ b/espresso/tools/utils.py @@ -92,6 +92,24 @@ def convert_padding_direction( return src_frames.gather(1, index) +def eval_str_nested_list_or_tuple(x, type=int): + if x is None: + return None + if isinstance(x, str): + x = eval(x) + if isinstance(x, list): + return list( + map(lambda s: eval_str_nested_list_or_tuple(s, type), x)) + elif isinstance(x, tuple): + return tuple( + map(lambda s: eval_str_nested_list_or_tuple(s, type), x)) + else: + try: + return type(x) + except TypeError: + raise TypeError + + def plot_attention(attention, hypo_sent, utt_id, 
save_dir): """This function plots the attention for an example and save the plot in save_dir with .pdf as its filename. diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index dd21c1058..d7f780013 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -308,7 +308,7 @@ if [ ${stage} -le 10 ]; then if $lm_shallow_fusion; then if ! $use_wordlm; then path="$path:$lmdir/$lm_checkpoint" - opts="$opts --lm-weight 0.7 --coverage-weight 0.01 --eos-factor 1.5" + opts="$opts --lm-weight 0.7 --eos-factor 1.5" decode_affix=shallow_fusion else path="$path:$wordlmdir/$lm_checkpoint" From 95fcd904abac382fabecd405bc3f309e910ebcb1 Mon Sep 17 00:00:00 2001 From: freewym Date: Thu, 27 Feb 2020 18:07:19 -0500 Subject: [PATCH 071/119] code adaptation/changes according to the commits on Feb 27-29, 2020 --- espresso/criterions/cross_entropy_v2.py | 43 +++--- .../label_smoothed_cross_entropy_v2.py | 73 +++++---- espresso/models/speech_fconv.py | 10 +- espresso/models/speech_lstm.py | 5 + espresso/models/speech_transformer.py | 10 +- espresso/speech_recognize.py | 138 +++++++++--------- espresso/speech_train.py | 36 +++-- espresso/tasks/speech_recognition.py | 28 +++- 8 files changed, 192 insertions(+), 151 deletions(-) diff --git a/espresso/criterions/cross_entropy_v2.py b/espresso/criterions/cross_entropy_v2.py index 95b4a823c..349e0d08d 100644 --- a/espresso/criterions/cross_entropy_v2.py +++ b/espresso/criterions/cross_entropy_v2.py @@ -17,24 +17,23 @@ logger = logging.getLogger(__name__) -@register_criterion('cross_entropy_v2') +@register_criterion("cross_entropy_v2") class CrossEntropyV2Criterion(CrossEntropyCriterion): def __init__(self, args, task): super().__init__(args, task) self.dictionary = task.target_dictionary - self.num_updates = -1 self.epoch = 0 @staticmethod def add_args(parser): """Add criterion-specific arguments to the parser.""" # fmt: off - parser.add_argument('--print-training-sample-interval', type=int, - metavar='N', dest='print_interval', default=500, - help='print a training sample (reference + ' - 'prediction) every this number of updates') + parser.add_argument("--print-training-sample-interval", type=int, + metavar="N", dest="print_interval", default=500, + help="print a training sample (reference + " + "prediction) every this number of updates") # fmt: on def forward(self, model, sample, reduce=True): @@ -46,26 +45,27 @@ def forward(self, model, sample, reduce=True): 2) the sample size, which is used as the denominator for the gradient 3) logging outputs to display while training """ - net_output = model(**sample['net_input'], epoch=self.epoch) + net_output = model(**sample["net_input"], epoch=self.epoch) loss, _, lprobs = self.compute_loss(model, net_output, sample, reduce=reduce) - sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] + sample_size = sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"] logging_output = { - 'loss': loss.data, - 'ntokens': sample['ntokens'], - 'nsentences': sample['target'].size(0), - 'sample_size': sample_size, + "loss": loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, } if ( - model.training and self.num_updates // self.args.print_interval > - (self.num_updates - 1) // self.args.print_interval + hasattr(model, "num_updates") and model.training and + model.num_updates // self.args.print_interval > + (model.num_updates - 1) // self.args.print_interval ): # print a randomly sampled result every 
print_interval updates target = model.get_targets(sample, net_output) pred = lprobs.argmax(-1).cpu() # bsz x len assert pred.size() == target.size() - with data_utils.numpy_seed(self.num_updates): - i = np.random.randint(0, len(sample['id'])) - ref_tokens = sample['target_raw_text'][i] + with data_utils.numpy_seed(model.num_updates): + i = np.random.randint(0, len(sample["id"])) + ref_tokens = sample["target_raw_text"][i] length = utils.strip_pad(target.data[i], self.padding_idx).size(0) ref_one = self.dictionary.tokens_to_sentence( ref_tokens, use_unk_sym=False, bpe_symbol=self.args.remove_bpe, @@ -74,8 +74,8 @@ def forward(self, model, sample, reduce=True): self.dictionary.string(pred.data[i][:length]), use_unk_sym=True, bpe_symbol=self.args.remove_bpe, ) - logger.info('sample REF: ' + ref_one) - logger.info('sample PRD: ' + pred_one) + logger.info("sample REF: " + ref_one) + logger.info("sample PRD: " + pred_one) return loss, sample_size, logging_output @@ -86,12 +86,9 @@ def compute_loss(self, model, net_output, sample, reduce=True): lprobs.view(-1, lprobs.size(-1)), target.view(-1), ignore_index=self.padding_idx, - reduction='sum' if reduce else 'none', + reduction="sum" if reduce else "none", ) return loss, loss, lprobs - def set_num_updates(self, num_updates): - self.num_updates = num_updates - def set_epoch(self, epoch): self.epoch = epoch diff --git a/espresso/criterions/label_smoothed_cross_entropy_v2.py b/espresso/criterions/label_smoothed_cross_entropy_v2.py index a441ab48c..8e3683339 100644 --- a/espresso/criterions/label_smoothed_cross_entropy_v2.py +++ b/espresso/criterions/label_smoothed_cross_entropy_v2.py @@ -46,21 +46,21 @@ def temporal_label_smoothing_prob_mask( def label_smoothed_nll_loss( lprobs, target, epsilon, ignore_index=None, reduce=True, - smoothing_type='uniform', prob_mask=None, unigram_tensor=None, + smoothing_type="uniform", prob_mask=None, unigram_tensor=None, ): if target.dim() == lprobs.dim() - 1: target = target.unsqueeze(-1) nll_loss = -lprobs.gather(dim=-1, index=target) - if smoothing_type == 'temporal': + if smoothing_type == "temporal": assert torch.is_tensor(prob_mask) smooth_loss = -lprobs.mul(prob_mask).sum(-1, keepdim=True) - elif smoothing_type == 'unigram': + elif smoothing_type == "unigram": assert torch.is_tensor(unigram_tensor) smooth_loss = -lprobs.matmul(unigram_tensor.to(lprobs)) - elif smoothing_type == 'uniform': + elif smoothing_type == "uniform": smooth_loss = -lprobs.sum(dim=-1, keepdim=True) else: - raise ValueError('Unsupported smoothing type: {}'.format(smoothing_type)) + raise ValueError("Unsupported smoothing type: {}".format(smoothing_type)) if ignore_index is not None: pad_mask = target.eq(ignore_index) if pad_mask.any(): @@ -72,22 +72,21 @@ def label_smoothed_nll_loss( if reduce: nll_loss = nll_loss.sum() smooth_loss = smooth_loss.sum() - eps_i = epsilon / lprobs.size(-1) if smoothing_type == 'uniform' else epsilon + eps_i = epsilon / lprobs.size(-1) if smoothing_type == "uniform" else epsilon loss = (1. 
- epsilon) * nll_loss + eps_i * smooth_loss return loss, nll_loss -@register_criterion('label_smoothed_cross_entropy_v2') +@register_criterion("label_smoothed_cross_entropy_v2") class LabelSmoothedCrossEntropyV2Criterion(LabelSmoothedCrossEntropyCriterion): def __init__(self, args, task): super().__init__(args, task) self.dictionary = task.target_dictionary - self.num_updates = -1 self.epoch = 0 self.unigram_tensor = None - if args.smoothing_type == 'unigram': + if args.smoothing_type == "unigram": self.unigram_tensor = torch.cuda.FloatTensor(self.dictionary.count).unsqueeze(-1) \ if torch.cuda.is_available() and not args.cpu \ else torch.FloatTensor(self.dictionary.count).unsqueeze(-1) @@ -99,16 +98,16 @@ def add_args(parser): """Add criterion-specific arguments to the parser.""" # fmt: off LabelSmoothedCrossEntropyCriterion.add_args(parser) - parser.add_argument('--print-training-sample-interval', type=int, - metavar='N', dest='print_interval', default=500, - help='print a training sample (reference + ' - 'prediction) every this number of updates') - parser.add_argument('--smoothing-type', type=str, default='uniform', - choices=['uniform', 'unigram', 'temporal'], - help='label smoothing type. Default: uniform') - parser.add_argument('--unigram-pseudo-count', type=float, default=1.0, - metavar='C', help='pseudo count for unigram label ' - 'smoothing. Only relevant if --smoothing-type=unigram') + parser.add_argument("--print-training-sample-interval", type=int, + metavar="N", dest="print_interval", default=500, + help="print a training sample (reference + " + "prediction) every this number of updates") + parser.add_argument("--smoothing-type", type=str, default="uniform", + choices=["uniform", "unigram", "temporal"], + help="label smoothing type. Default: uniform") + parser.add_argument("--unigram-pseudo-count", type=float, default=1.0, + metavar="C", help="pseudo count for unigram label " + "smoothing. 
Only relevant if --smoothing-type=unigram") # fmt: on def forward(self, model, sample, reduce=True): @@ -120,29 +119,30 @@ def forward(self, model, sample, reduce=True): 2) the sample size, which is used as the denominator for the gradient 3) logging outputs to display while training """ - net_output = model(**sample['net_input'], epoch=self.epoch) + net_output = model(**sample["net_input"], epoch=self.epoch) loss, nll_loss, lprobs = self.compute_loss( model, net_output, sample, reduce=reduce, smoothing_type=self.args.smoothing_type ) - sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] + sample_size = sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"] logging_output = { - 'loss': loss.data, - 'nll_loss': nll_loss.data, - 'ntokens': sample['ntokens'], - 'nsentences': sample['target'].size(0), - 'sample_size': sample_size, + "loss": loss.data, + "nll_loss": nll_loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, } if ( - model.training and self.num_updates // self.args.print_interval > - (self.num_updates - 1) // self.args.print_interval + hasattr(model, "num_updates") and model.training and + model.num_updates // self.args.print_interval > + (model.num_updates - 1) // self.args.print_interval ): # print a randomly sampled result every print_interval updates target = model.get_targets(sample, net_output) pred = lprobs.argmax(-1).cpu() # bsz x len assert pred.size() == target.size() - with data_utils.numpy_seed(self.num_updates): - i = np.random.randint(0, len(sample['id'])) - ref_tokens = sample['target_raw_text'][i] + with data_utils.numpy_seed(model.num_updates): + i = np.random.randint(0, len(sample["id"])) + ref_tokens = sample["target_raw_text"][i] length = utils.strip_pad(target.data[i], self.padding_idx).size(0) ref_one = self.dictionary.tokens_to_sentence( ref_tokens, use_unk_sym=False, bpe_symbol=self.args.remove_bpe, @@ -151,19 +151,19 @@ def forward(self, model, sample, reduce=True): self.dictionary.string(pred.data[i][:length]), use_unk_sym=True, bpe_symbol=self.args.remove_bpe, ) - logger.info('sample REF: ' + ref_one) - logger.info('sample PRD: ' + pred_one) + logger.info("sample REF: " + ref_one) + logger.info("sample PRD: " + pred_one) return loss, sample_size, logging_output def compute_loss( - self, model, net_output, sample, reduce=True, smoothing_type='uniform' + self, model, net_output, sample, reduce=True, smoothing_type="uniform" ): lprobs = model.get_normalized_probs(net_output, log_probs=True) target = model.get_targets(sample, net_output) prob_mask = temporal_label_smoothing_prob_mask( lprobs, target, padding_index=self.padding_idx, - ) if smoothing_type == 'temporal' else None + ) if smoothing_type == "temporal" else None loss, nll_loss = label_smoothed_nll_loss( lprobs.view(-1, lprobs.size(-1)), target.view(-1, 1), self.eps, ignore_index=self.padding_idx, reduce=reduce, @@ -172,8 +172,5 @@ def compute_loss( ) return loss, nll_loss, lprobs - def set_num_updates(self, num_updates): - self.num_updates = num_updates - def set_epoch(self, epoch): self.epoch = epoch diff --git a/espresso/models/speech_fconv.py b/espresso/models/speech_fconv.py index 287e8ca92..002ee56bb 100644 --- a/espresso/models/speech_fconv.py +++ b/espresso/models/speech_fconv.py @@ -55,6 +55,10 @@ class SpeechFConvModel(FConvModel): def hub_models(cls): raise NotImplementedError + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + self.num_updates = 
0 + @staticmethod def add_args(parser): """Add model-specific arguments to the parser.""" @@ -121,7 +125,11 @@ def build_model(cls, args, task): share_embed=args.share_input_output_embed, positional_embeddings=args.decoder_positional_embed, ) - return SpeechFConvModel(encoder, decoder) + return cls(encoder, decoder) + + def set_num_updates(self, num_updates): + self.num_updates = num_updates + super().set_num_updates(num_updates) class SpeechFConvEncoder(FConvEncoder): diff --git a/espresso/models/speech_lstm.py b/espresso/models/speech_lstm.py index 0dddb02dc..98e5c7096 100644 --- a/espresso/models/speech_lstm.py +++ b/espresso/models/speech_lstm.py @@ -42,6 +42,7 @@ class SpeechLSTMModel(FairseqEncoderDecoderModel): def __init__(self, encoder, decoder, pretrained_lm=None): super().__init__(encoder, decoder) + self.num_updates = 0 self.pretrained_lm = pretrained_lm if pretrained_lm is not None: assert isinstance(self.pretrained_lm, FairseqDecoder) @@ -232,6 +233,10 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): param.requires_grad = False return cls(encoder, decoder, pretrained_lm) + def set_num_updates(self, num_updates): + self.num_updates = num_updates + super().set_num_updates(num_updates) + def max_positions(self): """Maximum length supported by the model.""" return ( diff --git a/espresso/models/speech_transformer.py b/espresso/models/speech_transformer.py index 55eb883eb..cf4f4ddcc 100644 --- a/espresso/models/speech_transformer.py +++ b/espresso/models/speech_transformer.py @@ -57,6 +57,10 @@ class SpeechTransformerModel(TransformerModel): def hub_models(cls): raise NotImplementedError + def __init__(self, args, encoder, decoder): + super().__init__(args, encoder, decoder) + self.num_updates = 0 + @staticmethod def add_args(parser): """Add model-specific arguments to the parser.""" @@ -124,7 +128,11 @@ def build_embedding(dictionary, embed_dim, path=None): args, conv_layers_before=conv_layers, input_size=transformer_encoder_input_size, ) decoder = cls.build_decoder(args, dict, decoder_embed_tokens) - return SpeechTransformerModel(encoder, decoder) + return cls(encoder, decoder) + + def set_num_updates(self, num_updates): + self.num_updates = num_updates + super().set_num_updates(num_updates) @classmethod def build_encoder(cls, args, conv_layers_before=None, input_size=83): diff --git a/espresso/speech_recognize.py b/espresso/speech_recognize.py index f9218a9c9..36941d4d7 100755 --- a/espresso/speech_recognize.py +++ b/espresso/speech_recognize.py @@ -15,8 +15,9 @@ import torch -from fairseq import checkpoint_utils, options, progress_bar, tasks, utils -from fairseq.meters import StopwatchMeter, TimeMeter +from fairseq import checkpoint_utils, options, tasks, utils +from fairseq.logging import progress_bar +from fairseq.logging.meters import StopwatchMeter, TimeMeter from fairseq.models import FairseqLanguageModel from espresso.models.external_language_model import MultiLevelLanguageModel @@ -125,6 +126,12 @@ def _main(args, output_file): shard_id=args.shard_id, num_workers=args.num_workers, ).next_epoch_itr(shuffle=False) + progress = progress_bar.progress_bar( + itr, + log_format=args.log_format, + log_interval=args.log_interval, + default_log_format=('tqdm' if not args.no_progress_bar else 'none'), + ) # Initialize generator if args.match_source_len: @@ -138,70 +145,69 @@ def _main(args, output_file): scorer = wer.Scorer(dictionary, wer_output_filter=args.wer_output_filter) num_sentences = 0 has_target = True - with 
progress_bar.build_progress_bar(args, itr) as t: - wps_meter = TimeMeter() - for sample in t: - sample = utils.move_to_cuda(sample) if use_cuda else sample - if 'net_input' not in sample: - continue - - prefix_tokens = None - if args.prefix_size > 0: - prefix_tokens = sample['target'][:, :args.prefix_size] - - gen_timer.start() - hypos = task.inference_step( - generator, models, sample, prefix_tokens, lm_weight=args.lm_weight, - ) - num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos) - gen_timer.stop(num_generated_tokens) - - # obtain nonpad mask of encoder output to plot attentions - if args.print_alignment: - net_input = sample['net_input'] - src_tokens = net_input['src_tokens'] - output_lengths = models[0].encoder.output_lengths(net_input['src_lengths']) - nonpad_idxs = sequence_mask(output_lengths, models[0].encoder.output_lengths(src_tokens.size(1))) - - for i in range(len(sample['id'])): - has_target = sample['target'] is not None - utt_id = sample['utt_id'][i] - - # Retrieve the original sentences - if has_target: - target_str = sample['target_raw_text'][i] - if not args.quiet: - target_sent = dictionary.tokens_to_sentence( - target_str, use_unk_sym=False, bpe_symbol=args.remove_bpe, - ) - print('T-{}\t{}'.format(utt_id, target_sent), file=output_file) - - # Process top predictions - for j, hypo in enumerate(hypos[i][:args.nbest]): - hypo_str = dictionary.string(hypo['tokens'].int().cpu()) # not removing bpe at this point - if not args.quiet or i == 0: - hypo_sent = dictionary.tokens_to_sentence(hypo_str, bpe_symbol=args.remove_bpe) - - if not args.quiet: - score = hypo['score'] / math.log(2) # convert to base 2 - print('H-{}\t{}\t{}'.format(utt_id, hypo_sent, score), file=output_file) - - # Score and obtain attention only the top hypothesis - if j == 0: - # src_len x tgt_len - attention = hypo['attention'][nonpad_idxs[i]].float().cpu() \ - if args.print_alignment and hypo['attention'] is not None else None - if args.print_alignment and attention is not None: - save_dir = os.path.join(args.results_path, 'attn_plots') - os.makedirs(save_dir, exist_ok=True) - plot_attention(attention, hypo_sent, utt_id, save_dir) - scorer.add_prediction(utt_id, hypo_str, bpe_symbol=args.remove_bpe) - if has_target: - scorer.add_evaluation(utt_id, target_str, hypo_str, bpe_symbol=args.remove_bpe) - - wps_meter.update(num_generated_tokens) - t.log({'wps': round(wps_meter.avg)}) - num_sentences += sample['nsentences'] + wps_meter = TimeMeter() + for sample in progress: + sample = utils.move_to_cuda(sample) if use_cuda else sample + if 'net_input' not in sample: + continue + + prefix_tokens = None + if args.prefix_size > 0: + prefix_tokens = sample['target'][:, :args.prefix_size] + + gen_timer.start() + hypos = task.inference_step( + generator, models, sample, prefix_tokens, lm_weight=args.lm_weight, + ) + num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos) + gen_timer.stop(num_generated_tokens) + + # obtain nonpad mask of encoder output to plot attentions + if args.print_alignment: + net_input = sample['net_input'] + src_tokens = net_input['src_tokens'] + output_lengths = models[0].encoder.output_lengths(net_input['src_lengths']) + nonpad_idxs = sequence_mask(output_lengths, models[0].encoder.output_lengths(src_tokens.size(1))) + + for i in range(len(sample['id'])): + has_target = sample['target'] is not None + utt_id = sample['utt_id'][i] + + # Retrieve the original sentences + if has_target: + target_str = sample['target_raw_text'][i] + if not args.quiet: + target_sent = 
dictionary.tokens_to_sentence( + target_str, use_unk_sym=False, bpe_symbol=args.remove_bpe, + ) + print('T-{}\t{}'.format(utt_id, target_sent), file=output_file) + + # Process top predictions + for j, hypo in enumerate(hypos[i][:args.nbest]): + hypo_str = dictionary.string(hypo['tokens'].int().cpu()) # not removing bpe at this point + if not args.quiet or i == 0: + hypo_sent = dictionary.tokens_to_sentence(hypo_str, bpe_symbol=args.remove_bpe) + + if not args.quiet: + score = hypo['score'] / math.log(2) # convert to base 2 + print('H-{}\t{}\t{}'.format(utt_id, hypo_sent, score), file=output_file) + + # Score and obtain attention only the top hypothesis + if j == 0: + # src_len x tgt_len + attention = hypo['attention'][nonpad_idxs[i]].float().cpu() \ + if args.print_alignment and hypo['attention'] is not None else None + if args.print_alignment and attention is not None: + save_dir = os.path.join(args.results_path, 'attn_plots') + os.makedirs(save_dir, exist_ok=True) + plot_attention(attention, hypo_sent, utt_id, save_dir) + scorer.add_prediction(utt_id, hypo_str, bpe_symbol=args.remove_bpe) + if has_target: + scorer.add_evaluation(utt_id, target_str, hypo_str, bpe_symbol=args.remove_bpe) + + wps_meter.update(num_generated_tokens) + progress.log({'wps': round(wps_meter.avg)}) + num_sentences += sample['nsentences'] logger.info('NOTE: hypothesis and token scores are output in base 2') logger.info('Recognized {} utterances ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format( diff --git a/espresso/speech_train.py b/espresso/speech_train.py index 106bf8edd..3cf5a1f28 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -17,12 +17,10 @@ import numpy as np import torch -from fairseq import ( - checkpoint_utils, distributed_utils, metrics, options, progress_bar, tasks, utils -) +from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils from fairseq.data import iterators +from fairseq.logging import meters, metrics, progress_bar from fairseq.trainer import Trainer -from fairseq.meters import StopwatchMeter logging.basicConfig( @@ -87,7 +85,7 @@ def main(args, init_distributed=False): max_epoch = args.max_epoch or math.inf max_update = args.max_update or math.inf lr = trainer.get_lr() - train_meter = StopwatchMeter() + train_meter = meters.StopwatchMeter() train_meter.start() valid_subsets = args.valid_subset.split(',') while ( @@ -159,8 +157,15 @@ def train(args, trainer, task, epoch_itr): else args.update_freq[-1] ) itr = iterators.GroupedIterator(itr, update_freq) - progress = progress_bar.build_progress_bar( - args, itr, epoch_itr.epoch, no_progress_bar='simple', + progress = progress_bar.progress_bar( + itr, + log_format=args.log_format, + log_interval=args.log_interval, + epoch=epoch_itr.epoch, + tensorboard_logdir=( + args.tensorboard_logdir if distributed_utils.is_master(args) else None + ), + default_log_format=('tqdm' if not args.no_progress_bar else 'simple'), ) # task specific setup per epoch @@ -171,9 +176,6 @@ def train(args, trainer, task, epoch_itr): if hasattr(trainer.criterion, 'set_epoch'): trainer.criterion.set_epoch(epoch_itr.epoch) for samples in progress: - if hasattr(trainer.criterion, 'set_num_updates'): - trainer.criterion.set_num_updates(trainer.get_num_updates()) - log_output = trainer.train_step(samples) num_updates = trainer.get_num_updates() if log_output is None: @@ -235,10 +237,16 @@ def validate(args, trainer, task, epoch_itr, subsets): shard_id=args.distributed_rank, num_workers=args.num_workers, 
).next_epoch_itr(shuffle=False) - progress = progress_bar.build_progress_bar( - args, itr, epoch_itr.epoch, - prefix='valid on \'{}\' subset'.format(subset), - no_progress_bar='simple' + progress = progress_bar.progress_bar( + itr, + log_format=args.log_format, + log_interval=args.log_interval, + epoch=epoch_itr.epoch, + prefix=f"valid on '{subset}' subset", + tensorboard_logdir=( + args.tensorboard_logdir if distributed_utils.is_master(args) else None + ), + default_log_format=('tqdm' if not args.no_progress_bar else 'simple'), ) # create a new root metrics aggregator so validation metrics diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index 40483d34f..e4ccd077c 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -11,9 +11,9 @@ import torch -from fairseq import metrics, search, utils +from fairseq import search, utils from fairseq.data import ConcatDataset - +from fairseq.logging import metrics from fairseq.tasks import FairseqTask, register_task from espresso.data import ( @@ -173,7 +173,9 @@ def load_dictionary(cls, filename, non_lang_syms=None): return AsrDictionary.load(filename, f_non_lang_syms=non_lang_syms) @classmethod - def build_dictionary(cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8): + def build_dictionary( + cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8 + ): """Disable this method """ raise NotImplementedError @@ -240,11 +242,12 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs): self.tgt_dict.count[self.tgt_dict.unk()] = unk_count def build_generator(self, args): - if args.score_reference: + if getattr(args, "score_reference", False): args.score_reference = False logger.warning( "--score-reference is not applicable to speech recognition, ignoring it." ) + from fairseq.sequence_generator import SequenceGenerator # Choose search strategy. Defaults to Beam Search. 
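The getattr() guard introduced in the hunk above matters because build_generator() can be handed an argparse Namespace that never defined --score-reference (for example, one assembled from training options), where a direct attribute access would raise. A minimal illustration of the pattern, using a made-up Namespace purely for demonstration:

from argparse import Namespace

args = Namespace(beam=5)  # no 'score_reference' attribute on this namespace
assert getattr(args, "score_reference", False) is False  # falls back to the default
# args.score_reference would raise AttributeError here
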
@@ -272,19 +275,28 @@ def build_generator(self, args): assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling" if sampling: - search_strategy = search.Sampling(self.target_dictionary, sampling_topk, sampling_topp) + search_strategy = search.Sampling( + self.target_dictionary, sampling_topk, sampling_topp + ) elif diverse_beam_groups > 0: search_strategy = search.DiverseBeamSearch( - self.target_dictionary, diverse_beam_groups, diverse_beam_strength) + self.target_dictionary, diverse_beam_groups, diverse_beam_strength + ) elif match_source_len: # this is useful for tagging applications where the output # length should match the input length, so we hardcode the # length constraints for simplicity search_strategy = search.LengthConstrainedBeamSearch( - self.target_dictionary, min_len_a=1, min_len_b=0, max_len_a=1, max_len_b=0, + self.target_dictionary, + min_len_a=1, + min_len_b=0, + max_len_a=1, + max_len_b=0, ) elif diversity_rate > -1: - search_strategy = search.DiverseSiblingsSearch(self.target_dictionary, diversity_rate) + search_strategy = search.DiverseSiblingsSearch( + self.target_dictionary, diversity_rate + ) else: search_strategy = search.BeamSearch(self.target_dictionary) From 41c57255ea982cbcbd86cd0496e93088a1a87c20 Mon Sep 17 00:00:00 2001 From: freewym Date: Thu, 5 Mar 2020 03:50:15 -0500 Subject: [PATCH 072/119] code adaptation/changes according to the commits on Mar 3-10, 2020 --- espresso/criterions/cross_entropy_v2.py | 18 +++++----- .../label_smoothed_cross_entropy_v2.py | 32 +++++++++-------- espresso/models/speech_lstm.py | 15 ++++---- espresso/speech_train.py | 35 ++++++++++--------- espresso/tasks/speech_recognition.py | 6 ++-- 5 files changed, 57 insertions(+), 49 deletions(-) diff --git a/espresso/criterions/cross_entropy_v2.py b/espresso/criterions/cross_entropy_v2.py index 349e0d08d..45254fd8b 100644 --- a/espresso/criterions/cross_entropy_v2.py +++ b/espresso/criterions/cross_entropy_v2.py @@ -20,11 +20,13 @@ @register_criterion("cross_entropy_v2") class CrossEntropyV2Criterion(CrossEntropyCriterion): - def __init__(self, args, task): - super().__init__(args, task) + def __init__(self, task, sentence_avg, print_interval, remove_bpe): + super().__init__(task, sentence_avg) self.dictionary = task.target_dictionary - self.epoch = 0 + self.print_interval = print_interval + self.remove_bpe = remove_bpe + self.epoch = 1 @staticmethod def add_args(parser): @@ -47,7 +49,7 @@ def forward(self, model, sample, reduce=True): """ net_output = model(**sample["net_input"], epoch=self.epoch) loss, _, lprobs = self.compute_loss(model, net_output, sample, reduce=reduce) - sample_size = sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"] + sample_size = sample["target"].size(0) if self.sentence_avg else sample["ntokens"] logging_output = { "loss": loss.data, "ntokens": sample["ntokens"], @@ -57,8 +59,8 @@ def forward(self, model, sample, reduce=True): if ( hasattr(model, "num_updates") and model.training and - model.num_updates // self.args.print_interval > - (model.num_updates - 1) // self.args.print_interval + model.num_updates // self.print_interval > + (model.num_updates - 1) // self.print_interval ): # print a randomly sampled result every print_interval updates target = model.get_targets(sample, net_output) pred = lprobs.argmax(-1).cpu() # bsz x len @@ -68,11 +70,11 @@ def forward(self, model, sample, reduce=True): ref_tokens = sample["target_raw_text"][i] length = utils.strip_pad(target.data[i], self.padding_idx).size(0) ref_one = 
self.dictionary.tokens_to_sentence( - ref_tokens, use_unk_sym=False, bpe_symbol=self.args.remove_bpe, + ref_tokens, use_unk_sym=False, bpe_symbol=self.remove_bpe, ) pred_one = self.dictionary.tokens_to_sentence( self.dictionary.string(pred.data[i][:length]), use_unk_sym=True, - bpe_symbol=self.args.remove_bpe, + bpe_symbol=self.remove_bpe, ) logger.info("sample REF: " + ref_one) logger.info("sample PRD: " + pred_one) diff --git a/espresso/criterions/label_smoothed_cross_entropy_v2.py b/espresso/criterions/label_smoothed_cross_entropy_v2.py index 8e3683339..f21f3e14f 100644 --- a/espresso/criterions/label_smoothed_cross_entropy_v2.py +++ b/espresso/criterions/label_smoothed_cross_entropy_v2.py @@ -80,17 +80,21 @@ def label_smoothed_nll_loss( @register_criterion("label_smoothed_cross_entropy_v2") class LabelSmoothedCrossEntropyV2Criterion(LabelSmoothedCrossEntropyCriterion): - def __init__(self, args, task): - super().__init__(args, task) + def __init__( + self, task, sentence_avg, label_smoothing, smoothing_type, print_interval, + remove_bpe, unigram_pseudo_count, + ): + super().__init__(task, sentence_avg, label_smoothing) self.dictionary = task.target_dictionary - self.epoch = 0 + self.smoothing_type = smoothing_type + self.print_interval = print_interval + self.remove_bpe = remove_bpe + self.epoch = 1 self.unigram_tensor = None - if args.smoothing_type == "unigram": - self.unigram_tensor = torch.cuda.FloatTensor(self.dictionary.count).unsqueeze(-1) \ - if torch.cuda.is_available() and not args.cpu \ - else torch.FloatTensor(self.dictionary.count).unsqueeze(-1) - self.unigram_tensor += args.unigram_pseudo_count # for further backoff + if smoothing_type == "unigram": + self.unigram_tensor = torch.FloatTensor(self.dictionary.count).unsqueeze(-1) + self.unigram_tensor += unigram_pseudo_count # for further backoff self.unigram_tensor.div_(self.unigram_tensor.sum()) @staticmethod @@ -121,9 +125,9 @@ def forward(self, model, sample, reduce=True): """ net_output = model(**sample["net_input"], epoch=self.epoch) loss, nll_loss, lprobs = self.compute_loss( - model, net_output, sample, reduce=reduce, smoothing_type=self.args.smoothing_type + model, net_output, sample, reduce=reduce, smoothing_type=self.smoothing_type ) - sample_size = sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"] + sample_size = sample["target"].size(0) if self.sentence_avg else sample["ntokens"] logging_output = { "loss": loss.data, "nll_loss": nll_loss.data, @@ -134,8 +138,8 @@ def forward(self, model, sample, reduce=True): if ( hasattr(model, "num_updates") and model.training and - model.num_updates // self.args.print_interval > - (model.num_updates - 1) // self.args.print_interval + model.num_updates // self.print_interval > + (model.num_updates - 1) // self.print_interval ): # print a randomly sampled result every print_interval updates target = model.get_targets(sample, net_output) pred = lprobs.argmax(-1).cpu() # bsz x len @@ -145,11 +149,11 @@ def forward(self, model, sample, reduce=True): ref_tokens = sample["target_raw_text"][i] length = utils.strip_pad(target.data[i], self.padding_idx).size(0) ref_one = self.dictionary.tokens_to_sentence( - ref_tokens, use_unk_sym=False, bpe_symbol=self.args.remove_bpe, + ref_tokens, use_unk_sym=False, bpe_symbol=self.remove_bpe, ) pred_one = self.dictionary.tokens_to_sentence( self.dictionary.string(pred.data[i][:length]), use_unk_sym=True, - bpe_symbol=self.args.remove_bpe, + bpe_symbol=self.remove_bpe, ) logger.info("sample REF: " + ref_one) 
logger.info("sample PRD: " + pred_one) diff --git a/espresso/models/speech_lstm.py b/espresso/models/speech_lstm.py index 98e5c7096..bb1b24c4e 100644 --- a/espresso/models/speech_lstm.py +++ b/espresso/models/speech_lstm.py @@ -485,14 +485,13 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, - attention weights of shape `(batch, tgt_len, src_len)` """ if self.scheduled_sampling_rate_scheduler is not None: - epoch = kwargs.get('epoch', 0) - if epoch > 0: - sampling_prob = self.scheduled_sampling_rate_scheduler.step(epoch) - if sampling_prob < 1.0: # apply scheduled sampling - return self._forward_with_scheduled_sampling( - prev_output_tokens, sampling_prob, encoder_out=encoder_out, - incremental_state={}, # use empty dict to preserve forward state - ) + epoch = kwargs.get('epoch', 1) + sampling_prob = self.scheduled_sampling_rate_scheduler.step(epoch) + if sampling_prob < 1.0: # apply scheduled sampling + return self._forward_with_scheduled_sampling( + prev_output_tokens, sampling_prob, encoder_out=encoder_out, + incremental_state={}, # use empty dict to preserve forward state + ) x, attn_scores = self.extract_features( prev_output_tokens, encoder_out, incremental_state, diff --git a/espresso/speech_train.py b/espresso/speech_train.py index 3cf5a1f28..7f2b989c8 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -57,7 +57,7 @@ def main(args, init_distributed=False): # Load valid dataset (we load training data below, based on the latest checkpoint) for valid_sub_split in args.valid_subset.split(','): - task.load_dataset(valid_sub_split, combine=False, epoch=0) + task.load_dataset(valid_sub_split, combine=False, epoch=1) # Build model and criterion model = task.build_model(args) @@ -90,11 +90,7 @@ def main(args, init_distributed=False): valid_subsets = args.valid_subset.split(',') while ( lr > args.min_lr - and ( - epoch_itr.epoch < max_epoch - # allow resuming training from the final checkpoint - or epoch_itr._next_epoch_itr is not None - ) + and epoch_itr.next_epoch_idx <= max_epoch and trainer.get_num_updates() < max_update ): # train for one epoch @@ -118,7 +114,7 @@ def main(args, init_distributed=False): break epoch_itr = trainer.get_train_iterator( - epoch_itr.epoch, + epoch_itr.next_epoch_idx, # sharded data: get train iterator for next epoch load_dataset=(os.pathsep in getattr(args, 'data', '')), ) @@ -127,6 +123,9 @@ def main(args, init_distributed=False): def should_stop_early(args, valid_loss): + # skip check if no validation was done in the current epoch + if valid_loss is None: + return False if args.patience <= 0: return False @@ -149,7 +148,7 @@ def train(args, trainer, task, epoch_itr): # Initialize data iterator itr = epoch_itr.next_epoch_itr( fix_batches_to_gpus=args.fix_batches_to_gpus, - shuffle=(epoch_itr.epoch >= args.curriculum), + shuffle=(epoch_itr.next_epoch_idx > args.curriculum), ) update_freq = ( args.update_freq[epoch_itr.epoch - 1] @@ -176,14 +175,20 @@ def train(args, trainer, task, epoch_itr): if hasattr(trainer.criterion, 'set_epoch'): trainer.criterion.set_epoch(epoch_itr.epoch) for samples in progress: - log_output = trainer.train_step(samples) - num_updates = trainer.get_num_updates() - if log_output is None: - continue + with metrics.aggregate('train_inner'): + log_output = trainer.train_step(samples) + if log_output is None: # OOM, overflow, ... 
+ continue # log mid-epoch stats - stats = get_training_stats(metrics.get_smoothed_values('train')) - progress.log(stats, tag='train', step=num_updates) + num_updates = trainer.get_num_updates() + if num_updates % args.log_interval == 0: + stats = get_training_stats(metrics.get_smoothed_values('train_inner')) + progress.log(stats, tag='train_inner', step=num_updates) + + # reset mid-epoch stats after each log interval + # the end-of-epoch stats will still be preserved + metrics.reset_meters('train_inner') if ( not args.disable_validation @@ -321,8 +326,6 @@ def cli_main(modify_parser=None): port = random.randint(10000, 20000) args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port) args.distributed_rank = None # set based on device id - if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d': - logger.info('NOTE: you may get faster training with: --ddp-backend=no_c10d') torch.multiprocessing.spawn( fn=distributed_main, args=(args, ), diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index e4ccd077c..9870852d4 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -210,7 +210,7 @@ def setup_task(cls, args, **kwargs): else: return cls(args, tgt_dict) - def load_dataset(self, split, epoch=0, combine=False, **kwargs): + def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. Args: @@ -218,7 +218,7 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs): """ paths = utils.split_paths(self.args.data) assert len(paths) > 0 - data_path = paths[epoch % len(paths)] + data_path = paths[(epoch - 1) % len(paths)] self.datasets[split] = get_asr_dataset_from_json( data_path, split, self.tgt_dict, @@ -255,7 +255,7 @@ def build_generator(self, args): sampling_topk = getattr(args, "sampling_topk", -1) sampling_topp = getattr(args, "sampling_topp", -1.0) diverse_beam_groups = getattr(args, "diverse_beam_groups", -1) - diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5), + diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5) match_source_len = getattr(args, "match_source_len", False) diversity_rate = getattr(args, "diversity_rate", -1) if ( From 4b8d3be4d490dd52ca6c6053c1c140d5dd9fffe3 Mon Sep 17 00:00:00 2001 From: Yiming Wang Date: Wed, 11 Mar 2020 02:21:30 -0400 Subject: [PATCH 073/119] SpecAugment (#21) --- espresso/data/scp_text_dataset.py | 36 ++++++- espresso/data/speech_dataset.py | 7 ++ espresso/tasks/speech_recognition.py | 14 ++- espresso/tools/specaug_interpolate.py | 132 ++++++++++++++++++++++++++ examples/asr_librispeech/run.sh | 27 ++++-- examples/asr_swbd/run.sh | 23 +++-- examples/asr_wsj/run.sh | 12 +-- 7 files changed, 226 insertions(+), 25 deletions(-) create mode 100644 espresso/tools/specaug_interpolate.py diff --git a/espresso/data/scp_text_dataset.py b/espresso/data/scp_text_dataset.py index d39d23914..ea91f1944 100644 --- a/espresso/data/scp_text_dataset.py +++ b/espresso/data/scp_text_dataset.py @@ -10,6 +10,10 @@ import torch +from fairseq.data import data_utils + +from espresso.tools.specaug_interpolate import specaug + try: import kaldi_io except ImportError: @@ -26,6 +30,7 @@ class ScpDataset(torch.utils.data.Dataset): def __init__( self, utt_ids: List[str], rxfiles: List[str], utt2num_frames: Optional[List[int]] = None, + seed=1, specaugment_config: Optional[str] = None, ): super().__init__() assert len(utt_ids) == len(rxfiles) @@ -51,6 +56,9 @@ def __init__( assert len(self.sizes) == self.size 
self.sizes = np.array(self.sizes, dtype=np.int32) self.feat_dim = feat.shape[1] # feature dimension + self.seed = seed + self.specaugment_config = specaugment_config + self.epoch = 1 def check_index(self, i): if i < 0 or i >= self.size: @@ -68,9 +76,15 @@ def filter_and_reorder(self, indices): self.size = len(self.utt_ids) self.ordered_indices = list(range(self.size)) + def set_epoch(self, epoch): + self.epoch = epoch + def __getitem__(self, i): self.check_index(i) feat = kaldi_io.read_mat(self.rxfiles[i]) + if self.specaugment_config is not None and self.specaugment_config != "": + with data_utils.numpy_seed(self.seed, self.epoch, i): + feat = specaug(feat, **eval(self.specaugment_config)) item = torch.from_numpy(feat).float() return item @@ -91,9 +105,12 @@ class ScpCachedDataset(ScpDataset): def __init__( self, utt_ids: List[str], rxfiles: List[str], utt2num_frames: Optional[List[int]] = None, - ordered_prefetch=False, cache_size=4096, + seed=1, specaugment_config: Optional[str] = None, ordered_prefetch=False, cache_size=4096, ): - super().__init__(utt_ids, rxfiles, utt2num_frames=utt2num_frames) + super().__init__( + utt_ids, rxfiles, utt2num_frames=utt2num_frames, + seed=seed, specaugment_config=specaugment_config, + ) self.cache = None self.cache_index = {} self.cache_size = cache_size # in terms of number of examples @@ -150,7 +167,11 @@ def __getitem__(self, i): self.cache_index[idx] = ptx length = self.sizes[idx] dst = self.cache[ptx: ptx + length] - np.copyto(dst, kaldi_io.read_mat(self.rxfiles[idx])) + feat = kaldi_io.read_mat(self.rxfiles[idx]) + if self.specaugment_config is not None and self.specaugment_config != "": + with data_utils.numpy_seed(self.seed, self.epoch, idx): + feat = specaug(feat, **eval(self.specaugment_config)) + np.copyto(dst, feat) ptx += length ptx = self.cache_index[i] @@ -166,8 +187,12 @@ class ScpInMemoryDataset(ScpDataset): def __init__( self, utt_ids: List[str], rxfiles: List[str], utt2num_frames: Optional[List[int]] = None, + seed=1, specaugment_config: Optional[str] = None, ): - super().__init__(utt_ids, rxfiles, utt2num_frames=utt2num_frames) + super().__init__( + utt_ids, rxfiles, utt2num_frames=utt2num_frames, + seed=seed, specaugment_config=specaugment_config, + ) self.read_data() def read_data(self): @@ -189,6 +214,9 @@ def __getitem__(self, i): self.check_index(i) ptx = self.data_offsets[i] a = self.buffer[ptx: ptx + self.sizes[i]].copy() + if self.specaugment_config is not None and self.specaugment_config != "": + with data_utils.numpy_seed(self.seed, self.epoch, i): + a = specaug(a, **eval(self.specaugment_config)) return torch.from_numpy(a).float() diff --git a/espresso/data/speech_dataset.py b/espresso/data/speech_dataset.py index dcc76a35f..d1c945a5b 100644 --- a/espresso/data/speech_dataset.py +++ b/espresso/data/speech_dataset.py @@ -230,3 +230,10 @@ def supports_prefetch(self): def prefetch(self, indices): """Only prefetch src.""" self.src.prefetch(indices) + + def set_epoch(self, epoch): + super().set_epoch(epoch) + if hasattr(self.src, 'set_epoch'): + self.src.set_epoch(epoch) + if self.tgt is not None and hasattr(self.tgt, 'set_epoch'): + self.tgt.set_epoch(epoch) diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index 9870852d4..6acab68ba 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -31,6 +31,7 @@ def get_asr_dataset_from_json( data_path, split, tgt_dict, combine, upsample_primary, max_source_positions, max_target_positions, + seed=1, 
specaugment_config=None, ): """ Parse data json and create dataset. @@ -72,7 +73,9 @@ def get_asr_dataset_from_json( assert len(utt2num_frames) == 0 or len(utt_ids) == len(utt2num_frames) src_datasets.append(ScpCachedDataset( - utt_ids, feats, utt2num_frames=utt2num_frames, ordered_prefetch=True + utt_ids, feats, utt2num_frames=utt2num_frames, seed=seed, + specaugment_config=specaugment_config if split == "train" else None, + ordered_prefetch=True, )) if len(token_text) > 0: assert len(utt_ids) == len(token_text) @@ -161,6 +164,12 @@ def add_args(parser): help="amount to upsample primary dataset") parser.add_argument("--feat-in-channels", default=1, type=int, metavar="N", help="feature input channels") + parser.add_argument("--specaugment-config", default=None, type=str, metavar="EXPR", + help="SpecAugment config string. If not None and not empty, " + "then apply SpecAugment. Should be an evaluatable expression of " + "a python dict. See speech_tools.specaug_interpolate.specaug() for " + "all allowed arguments. Argments not appearing in this string " + "will take on their default values") # fmt: off @classmethod @@ -185,6 +194,7 @@ def __init__(self, args, tgt_dict, word_dict=None): self.tgt_dict = tgt_dict self.word_dict = word_dict self.feat_in_channels = args.feat_in_channels + self.specaugment_config = args.specaugment_config torch.backends.cudnn.deterministic = True # Compansate for the removel of :func:`torch.rand()` from # :func:`fairseq.distributed_utils.distributed_init()` by fairseq, @@ -226,6 +236,8 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): upsample_primary=self.args.upsample_primary, max_source_positions=self.args.max_source_positions, max_target_positions=self.args.max_target_positions, + seed=self.args.seed, + specaugment_config=self.specaugment_config, ) src_dataset = self.datasets[split].src diff --git a/espresso/tools/specaug_interpolate.py b/espresso/tools/specaug_interpolate.py new file mode 100644 index 000000000..6a7ed33ec --- /dev/null +++ b/espresso/tools/specaug_interpolate.py @@ -0,0 +1,132 @@ +# Copyright (c) Nanxin Chen, Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +This implementation is modified from https://github.com/zcaceres/spec_augment + +MIT License + +Copyright (c) 2019 Zach Caceres, Jenny Cai +""" + +import numpy as np + +import torch + + +def specaug(spec, W=80, F=27, T=70, num_freq_masks=2, num_time_masks=2, p=0.2, replace_with_zero=False): + """SpecAugment + + Reference: SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition + (https://arxiv.org/pdf/1904.08779.pdf) + + This implementation modified from https://github.com/zcaceres/spec_augment + + Args: + spec (numpy.ndarray): input tensor of the shape `(T, dim)` + W (int): time warp parameter + F (int): maximum width of each freq mask + T (int): maximum width of each time mask + num_freq_masks (int): number of frequency masks + num_time_masks (int): number of time masks + p (int): toal mask width shouldn't exeed this times num of frames + replace_with_zero (bool): if True, masked parts will be filled with 0, if False, filled with mean + + Returns: + output (numpy.ndarray): resultant matrix of shape `(T, dim)` + """ + spec = torch.from_numpy(spec) + if replace_with_zero: + pad_value = 0. 
+ else: + pad_value = spec.mean() + time_warped = time_warp(spec.transpose(0, 1), W=W) + freq_masked = freq_mask(time_warped, F=F, num_masks=num_freq_masks, pad_value=pad_value) + time_masked = time_mask(freq_masked, T=T, num_masks=num_time_masks, p=p, pad_value=pad_value) + return time_masked.transpose(0, 1).numpy() + + +def time_warp(spec, W=5): + """Time warping + + Args: + spec (torch.Tensor): input tensor of shape `(dim, T)` + W (int): time warp parameter + + Returns: + time warpped tensor (torch.Tensor): output tensor of shape `(dim, T)` + """ + t = spec.size(1) + if t - W <= W + 1: + return spec + center = np.random.randint(W + 1, t - W) + warped = np.random.randint(center - W, center + W + 1) + if warped == center: + return spec + spec = spec.unsqueeze(0).unsqueeze(0) + with torch.no_grad(): # to make the results deterministic + left = torch.nn.functional.interpolate( + spec[:, :, :, :center], size=(spec.size(2), warped), + mode="bicubic", align_corners=False, + ) + right = torch.nn.functional.interpolate( + spec[:, :, :, center:], size=(spec.size(2), t - warped), + mode="bicubic", align_corners=False, + ) + return torch.cat((left, right), dim=-1).squeeze(0).squeeze(0) + + +def freq_mask(spec, F=30, num_masks=1, pad_value=0.): + """Frequency masking + + Args: + spec (torch.Tensor): input tensor of shape `(dim, T)` + F (int): maximum width of each mask + num_masks (int): number of masks + pad_value (float): value for padding + + Returns: + freq masked tensor (torch.Tensor): output tensor of shape `(dim, T)` + """ + cloned = spec.clone() + num_mel_channels = cloned.size(0) + + for i in range(num_masks): + f = np.random.randint(0, F + 1) + f_zero = np.random.randint(0, num_mel_channels - f) + + if f == 0: + return cloned + cloned[f_zero:f_zero + f] = pad_value + return cloned + + +def time_mask(spec, T=40, num_masks=1, p=0.2, pad_value=0.): + """Time masking + + Args: + spec (torch.Tensor): input tensor of shape `(dim, T)` + T (int): maximum width of each mask + num_masks (int): number of masks + p (float): toal mask width shouldn't exeed this times num of frames + pad_value (float): value for padding + + Returns: + time masked tensor (torch.Tensor): output tensor of shape `(dim, T)` + """ + cloned = spec.clone() + len_spectro = cloned.size(1) + T = min(T, int(len_spectro * p)) + if T == 0: + return cloned + + for i in range(num_masks): + t = np.random.randint(0, T + 1) + t_zero = np.random.randint(0, len_spectro - t) + + if t == 0: + return cloned + cloned[:, t_zero:t_zero + t] = pad_value + return cloned diff --git a/examples/asr_librispeech/run.sh b/examples/asr_librispeech/run.sh index a9ac7ef18..91592cacb 100755 --- a/examples/asr_librispeech/run.sh +++ b/examples/asr_librispeech/run.sh @@ -35,6 +35,7 @@ kaldi_scoring=true # feature configuration do_delta=false +apply_specaug=false . 
./path.sh @@ -169,13 +170,13 @@ if [ ${stage} -le 5 ]; then [ -f $lmdir/checkpoint_last.pt ] && log_file="-a $log_file" CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../train.py $lmdatadir --seed 1 --user-dir espresso \ --task language_modeling_for_asr --dict $lmdict \ - --log-interval 8000 --log-format simple \ + --log-interval $((16000/ngpus)) --log-format simple \ --num-workers 0 --max-tokens 32000 --max-sentences 1024 --curriculum 1 \ --valid-subset $valid_subset --max-sentences-valid 1536 \ --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ --max-epoch 30 --optimizer adam --lr 0.001 --clip-norm 1.0 \ --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \ - --save-dir $lmdir --restore-file checkpoint_last.pt --save-interval-updates 8000 \ + --save-dir $lmdir --restore-file checkpoint_last.pt --save-interval-updates $((16000/ngpus)) \ --keep-interval-updates 3 --keep-last-epochs 5 --validate-interval 1 \ --arch lstm_lm_librispeech --criterion cross_entropy --sample-break-mode eos 2>&1 | tee $log_file fi @@ -219,20 +220,28 @@ if [ ${stage} -le 8 ]; then mkdir -p $dir/logs log_file=$dir/logs/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" + opts="" + if $apply_specaug; then + opts="$opts --max-epoch 95 --lr-scheduler tri_stage --warmup-steps $((2000/ngpus)) --hold-steps $((600000/ngpus)) --decay-steps $((1040000/ngpus))" + opts="$opts --encoder-rnn-layers 5" + specaug_config="{'W': 80, 'F': 27, 'T': 100, 'num_freq_masks': 2, 'num_time_masks': 2, 'p': 1.0}" + else + opts="$opts --max-epoch 30 --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --start-reduce-lr-epoch 10" + fi CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data --task speech_recognition_espresso --seed 1 --user-dir espresso \ - --log-interval 4000 --log-format simple --print-training-sample-interval 2000 \ + --log-interval $((8000/ngpus)) --log-format simple --print-training-sample-interval $((4000/ngpus)) \ --num-workers 0 --max-tokens 26000 --max-sentences 24 --curriculum 1 \ --valid-subset $valid_subset --max-sentences-valid 48 --ddp-backend no_c10d \ --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ - --max-epoch 30 --optimizer adam --lr 0.001 --weight-decay 0.0 --clip-norm 2.0 \ - --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --start-reduce-lr-epoch 10 \ - --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates 3000 \ + --optimizer adam --lr 0.001 --weight-decay 0.0 --clip-norm 2.0 \ + --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates $((6000/ngpus)) \ --keep-interval-updates 3 --keep-last-epochs 5 --validate-interval 1 --best-checkpoint-metric wer \ --arch speech_conv_lstm_librispeech --criterion label_smoothed_cross_entropy_v2 \ --label-smoothing 0.1 --smoothing-type uniform \ --scheduled-sampling-probs 1.0 --start-scheduled-sampling-epoch 1 \ --dict $dict --remove-bpe sentencepiece \ - --max-source-positions 9999 --max-target-positions 999 2>&1 | tee $log_file + --max-source-positions 9999 --max-target-positions 999 \ + $opts --specaugment-config "$specaug_config" 2>&1 | tee $log_file fi if [ ${stage} -le 9 ]; then @@ -243,6 +252,10 @@ if [ ${stage} -le 9 ]; then if $lm_shallow_fusion; then path="$path:$lmdir/$lm_checkpoint" opts="$opts --lm-weight 0.47 --eos-factor 1.5" + if $apply_specaug; then + # overwrite the existing opts + opts="$opts --lm-weight 0.4" + fi decode_affix=shallow_fusion fi for dataset in $test_set; do diff --git 
a/examples/asr_swbd/run.sh b/examples/asr_swbd/run.sh index 409a4eea1..433d5538d 100755 --- a/examples/asr_swbd/run.sh +++ b/examples/asr_swbd/run.sh @@ -40,6 +40,7 @@ fi # feature configuration do_delta=false +apply_specaug=false . ./path.sh @@ -207,13 +208,13 @@ if [ $stage -le 4 ]; then [ -f $lmdir/checkpoint_last.pt ] && log_file="-a $log_file" CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../train.py $lmdatadir --seed 1 --user-dir espresso \ --task language_modeling_for_asr --dict $lmdict \ - --log-interval 500 --log-format simple \ + --log-interval $((1000/ngpus)) --log-format simple \ --num-workers 0 --max-tokens 25600 --max-sentences 1024 \ --valid-subset $valid_subset --max-sentences-valid 1536 \ --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ --max-epoch 25 --optimizer adam --lr 0.001 --clip-norm 1.0 \ --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \ - --save-dir $lmdir --restore-file checkpoint_last.pt --save-interval-updates 500 \ + --save-dir $lmdir --restore-file checkpoint_last.pt --save-interval-updates $((1000/ngpus)) \ --keep-interval-updates 3 --keep-last-epochs 5 --validate-interval 1 \ --arch lstm_lm_swbd --criterion cross_entropy --sample-break-mode eos 2>&1 | tee $log_file fi @@ -260,20 +261,28 @@ if [ $stage -le 7 ]; then mkdir -p $dir/logs log_file=$dir/logs/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" + if $apply_specaug; then + opts="$opts --max-epoch 100 --lr-scheduler tri_stage --warmup-steps $((1000/ngpus)) --hold-steps $((140000/ngpus)) --decay-steps $((330000/ngpus))" + opts="$opts --encoder-rnn-hidden-size 1024 --encoder-rnn-layers 5 --decoder-embed-dim 512 --decoder-hidden-size 1024" + opts="$opts --decoder-out-embed-dim 3072 --attention-dim 512 --dropout 0.4" + specaug_config="{'W': 40, 'F': 18, 'T': 70, 'num_freq_masks': 2, 'num_time_masks': 2, 'p': 0.2}" + else + opts="$opts --max-epoch 35 --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --start-reduce-lr-epoch 14" + fi CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data --task speech_recognition_espresso --seed 1 --user-dir espresso \ - --log-interval 1500 --log-format simple --print-training-sample-interval 2000 \ + --log-interval $((3000/ngpus)) --log-format simple --print-training-sample-interval $((4000/ngpus)) \ --num-workers 0 --max-tokens 26000 --max-sentences 48 --curriculum 2 \ --valid-subset $valid_subset --max-sentences-valid 64 --ddp-backend no_c10d \ --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ - --max-epoch 35 --optimizer adam --lr 0.001 --weight-decay 0.0 --clip-norm 2.0 \ - --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --start-reduce-lr-epoch 14 \ - --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates 1500 \ + --optimizer adam --lr 0.001 --weight-decay 0.0 --clip-norm 2.0 \ + --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates $((3000/ngpus)) \ --keep-interval-updates 3 --keep-last-epochs 5 --validate-interval 1 --best-checkpoint-metric wer \ --arch speech_conv_lstm_swbd --criterion label_smoothed_cross_entropy_v2 \ --label-smoothing 0.1 --smoothing-type uniform \ --scheduled-sampling-probs 0.9,0.8,0.7,0.6 --start-scheduled-sampling-epoch 6 \ --dict $dict --remove-bpe sentencepiece --non-lang-syms $nlsyms \ - --max-source-positions 9999 --max-target-positions 999 $opts 2>&1 | tee $log_file + --max-source-positions 9999 --max-target-positions 999 \ + $opts --specaugment-config 
"$specaug_config" 2>&1 | tee $log_file fi if [ $stage -le 8 ]; then diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index d7f780013..d408e016f 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -194,13 +194,13 @@ if [ ${stage} -le 4 ] && ! $use_wordlm; then [ -f $lmdir/checkpoint_last.pt ] && log_file="-a $log_file" CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../train.py $lmdatadir --seed 1 --user-dir espresso \ --task language_modeling_for_asr --dict $lmdict \ - --log-interval 2000 --log-format simple \ + --log-interval $((4000/ngpus)) --log-format simple \ --num-workers 0 --max-tokens 25600 --max-sentences 128 \ --valid-subset $valid_subset --max-sentences-valid 256 \ --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ --max-epoch 25 --optimizer adam --lr 0.001 --weight-decay 5e-06 \ --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \ - --save-dir $lmdir --restore-file checkpoint_last.pt --save-interval-updates 2000 \ + --save-dir $lmdir --restore-file checkpoint_last.pt --save-interval-updates $((4000/ngpus)) \ --keep-interval-updates 5 --keep-last-epochs 5 --validate-interval 1 \ --arch lstm_lm_wsj --criterion cross_entropy --sample-break-mode eos 2>&1 | tee $log_file fi @@ -224,13 +224,13 @@ if [ ${stage} -le 6 ] && $use_wordlm; then [ -f $wordlmdir/checkpoint_last.pt ] && log_file="-a $log_file" CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../train.py $wordlmdatadir --seed 1 --user-dir espresso \ --task language_modeling_for_asr --dict $wordlmdict \ - --log-interval 2000 --log-format simple \ + --log-interval $((4000/ngpus)) --log-format simple \ --num-workers 0 --max-tokens 6400 --max-sentences 256 \ --valid-subset $valid_subset --max-sentences-valid 512 \ --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ --max-epoch 25 --optimizer adam --lr 0.001 --weight-decay 0.0 \ --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \ - --save-dir $wordlmdir --restore-file checkpoint_last.pt --save-interval-updates 2000 \ + --save-dir $wordlmdir --restore-file checkpoint_last.pt --save-interval-updates $((4000/ngpus)) \ --keep-interval-updates 5 --keep-last-epochs 5 --validate-interval 1 \ --arch lstm_wordlm_wsj --criterion cross_entropy \ --sample-break-mode eos 2>&1 | tee $log_file @@ -285,13 +285,13 @@ if [ ${stage} -le 9 ]; then log_file=$dir/logs/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data --task speech_recognition_espresso --seed 1 --user-dir espresso \ - --log-interval 400 --log-format simple --print-training-sample-interval 1000 \ + --log-interval $((800/ngpus)) --log-format simple --print-training-sample-interval $((2000/ngpus)) \ --num-workers 0 --max-tokens 24000 --max-sentences 32 --curriculum 2 \ --valid-subset $valid_subset --max-sentences-valid 64 --ddp-backend no_c10d \ --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ --max-epoch 35 --optimizer adam --lr 0.001 --weight-decay 0.0 \ --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --start-reduce-lr-epoch 11 \ - --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates 400 \ + --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates $((800/ngpus)) \ --keep-interval-updates 5 --keep-last-epochs 5 --validate-interval 1 --best-checkpoint-metric wer \ --arch speech_conv_lstm_wsj --criterion label_smoothed_cross_entropy_v2 
\ --label-smoothing 0.05 --smoothing-type temporal \ From 11e106f14428c3dc7303cfba27ba04a86585f05c Mon Sep 17 00:00:00 2001 From: freewym Date: Wed, 11 Mar 2020 12:58:26 -0400 Subject: [PATCH 074/119] code adaptation/changes according to the commits on Mar 11, 2020; change logs/->log/; rename SpeechDataset->AsrDataset, Scp*Dataset->FeatScp*Dataset, score*.sh->score*_e2e.sh; remove validation on train subset from wsj recipe --- .../label_smoothed_cross_entropy_v2.py | 5 +-- espresso/data/__init__.py | 17 +++++--- .../{speech_dataset.py => asr_dataset.py} | 2 +- ...p_text_dataset.py => feat_text_dataset.py} | 6 +-- espresso/speech_recognize.py | 2 +- espresso/speech_train.py | 2 +- espresso/tasks/speech_recognition.py | 18 ++++---- examples/asr_librispeech/local/score.sh | 1 - examples/asr_librispeech/local/score_e2e.sh | 1 + examples/asr_librispeech/run.sh | 16 +++---- .../{score_basic.sh => score_basic_e2e.sh} | 0 .../asr_swbd/local/{score.sh => score_e2e.sh} | 4 +- .../{score_sclite.sh => score_sclite_e2e.sh} | 0 examples/asr_swbd/run.sh | 17 ++++---- .../asr_wsj/local/{score.sh => score_e2e.sh} | 0 examples/asr_wsj/run.sh | 43 ++++++------------- ..._speech_dataset.py => test_asr_dataset.py} | 32 +++++++------- 17 files changed, 78 insertions(+), 88 deletions(-) rename espresso/data/{speech_dataset.py => asr_dataset.py} (99%) rename espresso/data/{scp_text_dataset.py => feat_text_dataset.py} (98%) delete mode 120000 examples/asr_librispeech/local/score.sh create mode 120000 examples/asr_librispeech/local/score_e2e.sh rename examples/asr_swbd/local/{score_basic.sh => score_basic_e2e.sh} (100%) rename examples/asr_swbd/local/{score.sh => score_e2e.sh} (91%) rename examples/asr_swbd/local/{score_sclite.sh => score_sclite_e2e.sh} (100%) rename examples/asr_wsj/local/{score.sh => score_e2e.sh} (100%) rename tests/espresso/{test_speech_dataset.py => test_asr_dataset.py} (89%) diff --git a/espresso/criterions/label_smoothed_cross_entropy_v2.py b/espresso/criterions/label_smoothed_cross_entropy_v2.py index f21f3e14f..53e7f3f9d 100644 --- a/espresso/criterions/label_smoothed_cross_entropy_v2.py +++ b/espresso/criterions/label_smoothed_cross_entropy_v2.py @@ -63,9 +63,8 @@ def label_smoothed_nll_loss( raise ValueError("Unsupported smoothing type: {}".format(smoothing_type)) if ignore_index is not None: pad_mask = target.eq(ignore_index) - if pad_mask.any(): - nll_loss.masked_fill_(pad_mask, 0.) - smooth_loss.masked_fill_(pad_mask, 0.) + nll_loss.masked_fill_(pad_mask, 0.) + smooth_loss.masked_fill_(pad_mask, 0.) else: nll_loss = nll_loss.squeeze(-1) smooth_loss = smooth_loss.squeeze(-1) diff --git a/espresso/data/__init__.py b/espresso/data/__init__.py index d2fcf423c..1c3cae450 100644 --- a/espresso/data/__init__.py +++ b/espresso/data/__init__.py @@ -3,15 +3,20 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from .asr_dataset import AsrDataset from .asr_dictionary import AsrDictionary -from .scp_text_dataset import AsrTextDataset, ScpCachedDataset, ScpDataset, ScpInMemoryDataset -from .speech_dataset import SpeechDataset +from .feat_text_dataset import ( + AsrTextDataset, + FeatScpCachedDataset, + FeatScpDataset, + FeatScpInMemoryDataset, +) __all__ = [ + 'AsrDataset', 'AsrDictionary', 'AsrTextDataset', - 'ScpCachedDataset', - 'ScpDataset', - 'ScpInMemoryDataset', - 'SpeechDataset', + 'FeatScpCachedDataset', + 'FeatScpDataset', + 'FeatScpInMemoryDataset', ] diff --git a/espresso/data/speech_dataset.py b/espresso/data/asr_dataset.py similarity index 99% rename from espresso/data/speech_dataset.py rename to espresso/data/asr_dataset.py index d1c945a5b..602fc7fa0 100644 --- a/espresso/data/speech_dataset.py +++ b/espresso/data/asr_dataset.py @@ -80,7 +80,7 @@ def merge(key, left_pad, move_eos_to_beginning=False): return batch -class SpeechDataset(FairseqDataset): +class AsrDataset(FairseqDataset): """ A pair of torch.utils.data.Datasets. diff --git a/espresso/data/scp_text_dataset.py b/espresso/data/feat_text_dataset.py similarity index 98% rename from espresso/data/scp_text_dataset.py rename to espresso/data/feat_text_dataset.py index ea91f1944..eb70e863f 100644 --- a/espresso/data/scp_text_dataset.py +++ b/espresso/data/feat_text_dataset.py @@ -20,7 +20,7 @@ raise ImportError('Please install kaldi_io with: pip install kaldi_io') -class ScpDataset(torch.utils.data.Dataset): +class FeatScpDataset(torch.utils.data.Dataset): """ A dataset for audio features prepared in Kaldi scp format (e.g., feats.scp). See http://kaldi-asr.org/doc/tutorial_running.html#tutorial_running_feats @@ -96,7 +96,7 @@ def exists(path): return os.path.exists(path) -class ScpCachedDataset(ScpDataset): +class FeatScpCachedDataset(FeatScpDataset): """ This class loads a batch of feature matrices (specified as *cache_size*) every time an entry is inquired. The inquire order should be known in advance. @@ -179,7 +179,7 @@ def __getitem__(self, i): return torch.from_numpy(a).float() -class ScpInMemoryDataset(ScpDataset): +class FeatScpInMemoryDataset(FeatScpDataset): """ This class loads all feature matrices into memory at once. It has the maximum memory usage and least I/O. 
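
Note on the pattern above: the FeatScp*Dataset classes draw their SpecAugment masks inside __getitem__ under data_utils.numpy_seed(self.seed, self.epoch, i), so every utterance gets a different mask in each epoch while staying reproducible across runs and dataloader workers. The standalone sketch below only illustrates that seeding pattern; numpy_seed here approximates fairseq.data.data_utils.numpy_seed, and apply_specaug and ToyFeatDataset are illustrative stand-ins, not code from this patch.

import contextlib
import numpy as np


@contextlib.contextmanager
def numpy_seed(seed, *extra):
    # Temporarily reseed NumPy so randomness inside the block is reproducible;
    # the extra seeds (here: epoch and example index) are folded into the base seed.
    if len(extra) > 0:
        seed = int(hash((seed, *extra)) % 1e6)
    state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(state)


def apply_specaug(feat):
    # Stand-in for specaug(): zero out one randomly chosen span of frames.
    feat = feat.copy()
    start = np.random.randint(0, feat.shape[0])
    width = np.random.randint(0, feat.shape[0] - start + 1)
    feat[start:start + width] = 0.0
    return feat


class ToyFeatDataset:
    def __init__(self, feats, seed=1):
        self.feats = feats  # list of (num_frames, feat_dim) numpy arrays
        self.seed = seed
        self.epoch = 1

    def set_epoch(self, epoch):
        # Called at the start of each epoch by the training loop (cf. AsrDataset.set_epoch).
        self.epoch = epoch

    def __getitem__(self, i):
        # Seeding with (seed, epoch, index) makes the augmentation deterministic
        # per example within an epoch, yet different from one epoch to the next.
        with numpy_seed(self.seed, self.epoch, i):
            return apply_specaug(self.feats[i])


feats = [np.random.rand(100, 40).astype(np.float32) for _ in range(4)]
dataset = ToyFeatDataset(feats, seed=1)
assert np.array_equal(dataset[0], dataset[0])  # same epoch and index: identical mask
dataset.set_epoch(2)                           # next epoch: masks are redrawn, still reproducibly
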
diff --git a/espresso/speech_recognize.py b/espresso/speech_recognize.py index 36941d4d7..c44117f6b 100755 --- a/espresso/speech_recognize.py +++ b/espresso/speech_recognize.py @@ -70,7 +70,7 @@ def _main(args, output_file): # Load ensemble logger.info('loading model(s) from {}'.format(args.path)) models, _model_args = checkpoint_utils.load_model_ensemble( - args.path.split(os.pathsep), + utils.split_paths(args.path), arg_overrides=eval(args.model_overrides), task=task, ) diff --git a/espresso/speech_train.py b/espresso/speech_train.py index 7f2b989c8..79d04f77a 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -139,7 +139,7 @@ def is_better(a, b): return False else: should_stop_early.num_runs += 1 - return should_stop_early.num_runs > args.patience + return should_stop_early.num_runs >= args.patience @metrics.aggregate('train') diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index 6acab68ba..ab8b2eca2 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -17,10 +17,10 @@ from fairseq.tasks import FairseqTask, register_task from espresso.data import ( + AsrDataset, AsrDictionary, AsrTextDataset, - ScpCachedDataset, - SpeechDataset, + FeatScpCachedDataset, ) @@ -72,7 +72,7 @@ def get_asr_dataset_from_json( utt2num_frames.append(int(val["utt2num_frames"])) assert len(utt2num_frames) == 0 or len(utt_ids) == len(utt2num_frames) - src_datasets.append(ScpCachedDataset( + src_datasets.append(FeatScpCachedDataset( utt_ids, feats, utt2num_frames=utt2num_frames, seed=seed, specaugment_config=specaugment_config if split == "train" else None, ordered_prefetch=True, @@ -107,7 +107,7 @@ def get_asr_dataset_from_json( tgt_dataset = None tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None - return SpeechDataset( + return AsrDataset( src_dataset, src_dataset.sizes, tgt_dataset, tgt_dataset_sizes, tgt_dict, @@ -329,7 +329,7 @@ def build_generator(self, args): ) def build_dataset_for_inference(self, src_tokens, src_lengths): - return SpeechDataset(src_tokens, src_lengths) + return AsrDataset(src_tokens, src_lengths) def build_model(self, args): # build the greedy decoder for validation with WER @@ -353,10 +353,10 @@ def inference_step(self, generator, models, sample, prefix_tokens=None, lm_weigh def reduce_metrics(self, logging_outputs, criterion): super().reduce_metrics(logging_outputs, criterion) - word_error = utils.item(sum(log.get("word_error", 0) for log in logging_outputs)) - word_count = utils.item(sum(log.get("word_count", 0) for log in logging_outputs)) - char_error = utils.item(sum(log.get("char_error", 0) for log in logging_outputs)) - char_count = utils.item(sum(log.get("char_count", 0) for log in logging_outputs)) + word_error = sum(log.get("word_error", 0) for log in logging_outputs) + word_count = sum(log.get("word_count", 0) for log in logging_outputs) + char_error = sum(log.get("char_error", 0) for log in logging_outputs) + char_count = sum(log.get("char_count", 0) for log in logging_outputs) if word_count > 0: metrics.log_scalar("wer", float(word_error) / word_count * 100, word_count, round=4) if char_count > 0: diff --git a/examples/asr_librispeech/local/score.sh b/examples/asr_librispeech/local/score.sh deleted file mode 120000 index 3a771d6c9..000000000 --- a/examples/asr_librispeech/local/score.sh +++ /dev/null @@ -1 +0,0 @@ -../../asr_wsj/local/score.sh \ No newline at end of file diff --git a/examples/asr_librispeech/local/score_e2e.sh 
b/examples/asr_librispeech/local/score_e2e.sh new file mode 120000 index 000000000..4d684c2cc --- /dev/null +++ b/examples/asr_librispeech/local/score_e2e.sh @@ -0,0 +1 @@ +../../asr_wsj/local/score_e2e.sh \ No newline at end of file diff --git a/examples/asr_librispeech/run.sh b/examples/asr_librispeech/run.sh index 91592cacb..944ca2519 100755 --- a/examples/asr_librispeech/run.sh +++ b/examples/asr_librispeech/run.sh @@ -144,10 +144,10 @@ fi lmdict=$dict if [ ${stage} -le 4 ]; then echo "Stage 4: Text Binarization for subword LM Training" - mkdir -p $lmdatadir/logs + mkdir -p $lmdatadir/log for dataset in $test_set; do test_paths="$test_paths $lmdatadir/$dataset.tokens"; done test_paths=$(echo $test_paths | awk '{$1=$1;print}' | tr ' ' ',') - ${decode_cmd} $lmdatadir/logs/preprocess.log \ + ${decode_cmd} $lmdatadir/log/preprocess.log \ python3 ../../preprocess.py --user-dir espresso --task language_modeling_for_asr \ --workers 50 --srcdict $lmdict --only-source \ --trainpref $lmdatadir/train.tokens \ @@ -165,8 +165,8 @@ fi if [ ${stage} -le 5 ]; then echo "Stage 5: subword LM Training" valid_subset=valid - mkdir -p $lmdir/logs - log_file=$lmdir/logs/train.log + mkdir -p $lmdir/log + log_file=$lmdir/log/train.log [ -f $lmdir/checkpoint_last.pt ] && log_file="-a $log_file" CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../train.py $lmdatadir --seed 1 --user-dir espresso \ --task language_modeling_for_asr --dict $lmdict \ @@ -188,7 +188,7 @@ if [ ${stage} -le 6 ]; then for i in $(seq $num); do gen_set_array[$i]="test$i"; done test_set_array=($test_set) for i in $(seq 0 $num); do - log_file=$lmdir/logs/evaluation_${test_set_array[$i]}.log + log_file=$lmdir/log/evaluation_${test_set_array[$i]}.log python3 ../../eval_lm.py $lmdatadir --user-dir espresso --cpu \ --task language_modeling_for_asr --dict $lmdict --gen-subset ${gen_set_array[$i]} \ --max-tokens 40960 --max-sentences 1536 --sample-break-mode eos \ @@ -217,8 +217,8 @@ fi if [ ${stage} -le 8 ]; then echo "Stage 8: Model Training" valid_subset=valid - mkdir -p $dir/logs - log_file=$dir/logs/train.log + mkdir -p $dir/log + log_file=$dir/log/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" opts="" if $apply_specaug; then @@ -270,7 +270,7 @@ if [ ${stage} -le 9 ]; then echo "log saved in ${decode_dir}/decode.log" if $kaldi_scoring; then echo "verify WER by scoring with Kaldi..." - local/score.sh data/$dataset $decode_dir + local/score_e2e.sh data/$dataset $decode_dir cat ${decode_dir}/scoring_kaldi/wer fi done diff --git a/examples/asr_swbd/local/score_basic.sh b/examples/asr_swbd/local/score_basic_e2e.sh similarity index 100% rename from examples/asr_swbd/local/score_basic.sh rename to examples/asr_swbd/local/score_basic_e2e.sh diff --git a/examples/asr_swbd/local/score.sh b/examples/asr_swbd/local/score_e2e.sh similarity index 91% rename from examples/asr_swbd/local/score.sh rename to examples/asr_swbd/local/score_e2e.sh index 78739eae4..2128fba00 100755 --- a/examples/asr_swbd/local/score.sh +++ b/examples/asr_swbd/local/score_e2e.sh @@ -28,8 +28,8 @@ data=$1 if [ -f $data/stm ]; then # use sclite scoring. 
echo "$data/stm exists: using local/score_sclite.sh" - eval local/score_sclite.sh $orig_args + eval local/score_sclite_e2e.sh $orig_args else echo "$data/stm does not exist: using local/score_basic.sh" - eval local/score_basic.sh $orig_args + eval local/score_basic_e2e.sh $orig_args fi diff --git a/examples/asr_swbd/local/score_sclite.sh b/examples/asr_swbd/local/score_sclite_e2e.sh similarity index 100% rename from examples/asr_swbd/local/score_sclite.sh rename to examples/asr_swbd/local/score_sclite_e2e.sh diff --git a/examples/asr_swbd/run.sh b/examples/asr_swbd/run.sh index 433d5538d..f3de6287f 100755 --- a/examples/asr_swbd/run.sh +++ b/examples/asr_swbd/run.sh @@ -182,10 +182,10 @@ fi lmdict=$dict if [ $stage -le 3 ]; then echo "Stage 3: Text Binarization for subword LM Training" - mkdir -p $lmdatadir/logs + mkdir -p $lmdatadir/log test_paths= && for dataset in $test_set; do test_paths="$test_paths $lmdatadir/$dataset.tokens"; done test_paths=$(echo $test_paths | awk '{$1=$1;print}' | tr ' ' ',') - ${decode_cmd} $lmdatadir/logs/preprocess.log \ + ${decode_cmd} $lmdatadir/log/preprocess.log \ python3 ../../preprocess.py --user-dir espresso --task language_modeling_for_asr \ --workers 50 --srcdict $lmdict --only-source \ --trainpref $lmdatadir/train.tokens \ @@ -203,8 +203,8 @@ fi if [ $stage -le 4 ]; then echo "Stage 4: subword LM Training" valid_subset=valid - mkdir -p $lmdir/logs - log_file=$lmdir/logs/train.log + mkdir -p $lmdir/log + log_file=$lmdir/log/train.log [ -f $lmdir/checkpoint_last.pt ] && log_file="-a $log_file" CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../train.py $lmdatadir --seed 1 --user-dir espresso \ --task language_modeling_for_asr --dict $lmdict \ @@ -226,7 +226,7 @@ if [ $stage -le 5 ]; then for i in $(seq $num); do gen_set_array[$i]="test$i"; done #gen_set_array=(test test1 test2) test_set_array=($test_set) for i in $(seq 0 $num); do - log_file=$lmdir/logs/evaluation_${test_set_array[$i]}.log + log_file=$lmdir/log/evaluation_${test_set_array[$i]}.log python3 ../../eval_lm.py $lmdatadir --user-dir espresso --cpu \ --task language_modeling_for_asr --dict $lmdict --gen-subset ${gen_set_array[$i]} \ --max-tokens 40960 --max-sentences 1536 --sample-break-mode eos \ @@ -258,8 +258,8 @@ if [ $stage -le 7 ]; then valid_subset=valid opts="" [ -f local/wer_output_filter ] && opts="$opts --wer-output-filter local/wer_output_filter" - mkdir -p $dir/logs - log_file=$dir/logs/train.log + mkdir -p $dir/log + log_file=$dir/log/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" if $apply_specaug; then opts="$opts --max-epoch 100 --lr-scheduler tri_stage --warmup-steps $((1000/ngpus)) --hold-steps $((140000/ngpus)) --decay-steps $((330000/ngpus))" @@ -298,6 +298,7 @@ if [ $stage -le 8 ]; then fi [ -f local/wer_output_filter ] && opts="$opts --wer-output-filter local/wer_output_filter" for dataset in $test_set; do + decode_dir=$dir/decode_${dataset}${decode_affix:+_${decode_affix}} CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py data \ --task speech_recognition_espresso --user-dir espresso --max-tokens 24000 --max-sentences 48 \ --num-shards 1 --shard-id 0 --dict $dict --remove-bpe sentencepiece --non-lang-syms $nlsyms --gen-subset $dataset \ @@ -307,7 +308,7 @@ if [ $stage -le 8 ]; then echo "log saved in ${decode_dir}/decode.log" echo "Scoring with kaldi..." 
- local/score.sh data/$dataset $decode_dir + local/score_e2e.sh data/$dataset $decode_dir if [ "$dataset" == "train_dev" ]; then echo -n "tran_dev: " && cat $decode_dir/scoring/wer | grep WER elif [ "$dataset" == "eval2000" ] || [ "$dataset" == "rt03" ]; then diff --git a/examples/asr_wsj/local/score.sh b/examples/asr_wsj/local/score_e2e.sh similarity index 100% rename from examples/asr_wsj/local/score.sh rename to examples/asr_wsj/local/score_e2e.sh diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index d408e016f..4358f2619 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -16,7 +16,6 @@ train_set=train_si284 valid_set=test_dev93 test_set=test_eval92 checkpoint=checkpoint_best.pt -validate_on_train_subset=false # for monitoring E2E model training # LM related lm_affix= @@ -34,7 +33,6 @@ if [[ $(hostname -f) == *.clsp.jhu.edu ]]; then wsj0=/export/corpora5/LDC/LDC93S6B wsj1=/export/corpora5/LDC/LDC94S13B fi -train_subset_size=500 # for validation if validate_on_train_subset is set to true kaldi_scoring=true # feature configuration @@ -66,7 +64,6 @@ if [ ${stage} -le 0 ]; then fi train_feat_dir=${dumpdir}/${train_set}/delta${do_delta}; mkdir -p ${train_feat_dir} -train_subset_feat_dir=${dumpdir}/${train_set}_${train_subset_size}/delta${do_delta}; mkdir -p ${train_subset_feat_dir} valid_feat_dir=${dumpdir}/${valid_set}/delta${do_delta}; mkdir -p ${valid_feat_dir} test_feat_dir=${dumpdir}/${test_set}/delta${do_delta}; mkdir -p ${test_feat_dir} if [ ${stage} -le 1 ]; then @@ -93,11 +90,6 @@ if [ ${stage} -le 1 ]; then data/${valid_set}/feats.scp data/${train_set}/cmvn.ark exp/dump_feats/valid ${valid_feat_dir} dump.sh --cmd "$train_cmd" --nj 4 --do_delta $do_delta \ data/${test_set}/feats.scp data/${train_set}/cmvn.ark exp/dump_feats/test ${test_feat_dir} - - # randomly select a subset of train set for optional diagnosis - utils/subset_data_dir.sh data/${train_set} ${train_subset_size} data/${train_set}_${train_subset_size} - utils/filter_scp.pl data/${train_set}_${train_subset_size}/utt2spk ${train_feat_dir}/feats.scp \ - > ${train_subset_feat_dir}/feats.scp fi dict=data/lang/${train_set}_units.txt @@ -115,7 +107,7 @@ if [ ${stage} -le 2 ]; then cat $nlsyms echo "$0: making a dictionary and tokenizing text for train/valid/test set..." - for dataset in $train_set ${train_set}_${train_subset_size} $valid_set $test_set; do + for dataset in $train_set $valid_set $test_set; do text=data/$dataset/text token_text=data/$dataset/token_text text2token.py --skip-ncols 1 --space "" --non-lang-syms $nlsyms $text > $token_text @@ -159,8 +151,8 @@ if [ ${stage} -le 3 ]; then echo "Stage 3: Text Binarization for LM Training" if ! $use_wordlm; then echo "$0: binarizing char text..." - mkdir -p $lmdatadir/logs - ${decode_cmd} $lmdatadir/logs/preprocess.log \ + mkdir -p $lmdatadir/log + ${decode_cmd} $lmdatadir/log/preprocess.log \ python3 ../../preprocess.py --user-dir espresso --task language_modeling_for_asr \ --workers 30 --srcdict $lmdict --only-source \ --trainpref $lmdatadir/train.tokens \ @@ -169,8 +161,8 @@ if [ ${stage} -le 3 ]; then --destdir $lmdatadir else echo "$0: binarizing word text..." 
- mkdir -p $wordlmdatadir/logs - ${decode_cmd} $wordlmdatadir/logs/preprocess.log \ + mkdir -p $wordlmdatadir/log + ${decode_cmd} $wordlmdatadir/log/preprocess.log \ python3 ../../preprocess.py --user-dir espresso --task language_modeling_for_asr \ --workers 30 --srcdict $wordlmdict --only-source \ --trainpref $wordlmdatadir/train \ @@ -189,8 +181,8 @@ fi if [ ${stage} -le 4 ] && ! $use_wordlm; then echo "Stage 4: char LM Training" valid_subset=valid - mkdir -p $lmdir/logs - log_file=$lmdir/logs/train.log + mkdir -p $lmdir/log + log_file=$lmdir/log/train.log [ -f $lmdir/checkpoint_last.pt ] && log_file="-a $log_file" CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../train.py $lmdatadir --seed 1 --user-dir espresso \ --task language_modeling_for_asr --dict $lmdict \ @@ -208,7 +200,7 @@ fi if [ ${stage} -le 5 ] && ! $use_wordlm; then echo "Stage 5: char LM Evaluation" for gen_subset in valid test; do - log_file=$lmdir/logs/evaluation_$gen_subset.log + log_file=$lmdir/log/evaluation_$gen_subset.log python3 ../../eval_lm.py $lmdatadir --user-dir espresso --cpu \ --task language_modeling_for_asr --dict $lmdict --gen-subset $gen_subset \ --max-tokens 192000 --max-sentences 256 --sample-break-mode eos \ @@ -219,8 +211,8 @@ fi if [ ${stage} -le 6 ] && $use_wordlm; then echo "Stage 6: word LM Training" valid_subset=valid - mkdir -p $wordlmdir/logs - log_file=$wordlmdir/logs/train.log + mkdir -p $wordlmdir/log + log_file=$wordlmdir/log/train.log [ -f $wordlmdir/checkpoint_last.pt ] && log_file="-a $log_file" CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../train.py $wordlmdatadir --seed 1 --user-dir espresso \ --task language_modeling_for_asr --dict $wordlmdict \ @@ -239,7 +231,7 @@ fi if [ ${stage} -le 7 ] && $use_wordlm; then echo "Stage 7: word LM Evaluation" for gen_subset in valid test; do - log_file=$wordlmdir/logs/evaluation_$gen_subset.log + log_file=$wordlmdir/log/evaluation_$gen_subset.log python3 ../../eval_lm.py $wordlmdatadir --user-dir espresso --cpu \ --task language_modeling_for_asr --dict $wordlmdict --gen-subset $gen_subset \ --max-tokens 12800 --max-sentences 512 --sample-break-mode eos \ @@ -255,12 +247,8 @@ if [ ${stage} -le 8 ]; then valid_feat=$valid_feat_dir/feats.scp valid_token_text=data/$valid_set/token_text valid_utt2num_frames=data/$valid_set/utt2num_frames - train_subset_feat=$train_subset_feat_dir/feats.scp - train_subset_token_text=data/${train_set}_${train_subset_size}/token_text - train_subset_utt2num_frames=data/${train_set}_${train_subset_size}/utt2num_frames asr_prep_json.py --feat-files $train_feat --token-text-files $train_token_text --utt2num-frames-files $train_utt2num_frames --output data/train.json asr_prep_json.py --feat-files $valid_feat --token-text-files $valid_token_text --utt2num-frames-files $valid_utt2num_frames --output data/valid.json - asr_prep_json.py --feat-files $train_subset_feat --token-text-files $train_subset_token_text --utt2num-frames-files $train_subset_utt2num_frames --output data/train_subset.json for dataset in $valid_set $test_set; do if [ "$dataset" == "$valid_set" ]; then feat=$valid_feat_dir/feats.scp @@ -277,12 +265,9 @@ if [ ${stage} -le 9 ]; then echo "Stage 9: Model Training" opts="" valid_subset=valid - if $validate_on_train_subset; then - valid_subset="$valid_subset,train_subset" - fi [ -f local/wer_output_filter ] && opts="$opts --wer-output-filter local/wer_output_filter" - mkdir -p $dir/logs - log_file=$dir/logs/train.log + mkdir -p $dir/log + log_file=$dir/log/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" 
CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data --task speech_recognition_espresso --seed 1 --user-dir espresso \ --log-interval $((800/ngpus)) --log-format simple --print-training-sample-interval $((2000/ngpus)) \ @@ -329,7 +314,7 @@ if [ ${stage} -le 10 ]; then echo "log saved in ${decode_dir}/decode.log" if $kaldi_scoring; then echo "verify WER by scoring with Kaldi..." - local/score.sh data/$dataset $decode_dir + local/score_e2e.sh data/$dataset $decode_dir cat ${decode_dir}/scoring_kaldi/wer fi done diff --git a/tests/espresso/test_speech_dataset.py b/tests/espresso/test_asr_dataset.py similarity index 89% rename from tests/espresso/test_speech_dataset.py rename to tests/espresso/test_asr_dataset.py index cae61ce50..d894e7d51 100644 --- a/tests/espresso/test_speech_dataset.py +++ b/tests/espresso/test_asr_dataset.py @@ -11,11 +11,11 @@ import torch from espresso.data import ( + AsrDataset, AsrDictionary, AsrTextDataset, - ScpCachedDataset, - ScpInMemoryDataset, - SpeechDataset, + FeatScpCachedDataset, + FeatScpInMemoryDataset, ) try: @@ -24,7 +24,7 @@ raise ImportError('Please install kaldi_io with: pip install kaldi_io') -class TestSpeechDataset(unittest.TestCase): +class TestAsrDataset(unittest.TestCase): @staticmethod def make_dictionary(): @@ -99,11 +99,11 @@ def setUp(self): self.cuda = torch.cuda.is_available() - def _speech_dataset_helper( + def _asr_dataset_helper( self, all_in_memory=False, ordered_prefetch=False, has_utt2num_frames=False, ): if not all_in_memory: - src_dataset = ScpCachedDataset( + src_dataset = FeatScpCachedDataset( utt_ids=self.feats_utt_ids, rxfiles=self.rxfiles, utt2num_frames=self.utt2num_frames if has_utt2num_frames else None, @@ -111,7 +111,7 @@ def _speech_dataset_helper( cache_size=self.cache_size, ) else: - src_dataset = ScpInMemoryDataset( + src_dataset = FeatScpInMemoryDataset( utt_ids=self.feats_utt_ids, rxfiles=self.rxfiles, utt2num_frames=self.utt2num_frames if has_utt2num_frames else None, @@ -122,7 +122,7 @@ def _speech_dataset_helper( dictionary=self.dictionary, ) - dataset = SpeechDataset( + dataset = AsrDataset( src_dataset, src_dataset.sizes, tgt_dataset, tgt_dataset.sizes, self.dictionary, left_pad_source=False, @@ -170,17 +170,17 @@ def _speech_dataset_helper( tgt_tokens[j], ) - def test_speech_dataset_cached_no_ordered_prefetch(self): - self._speech_dataset_helper(all_in_memory=False, ordered_prefetch=False) + def test_asr_dataset_cached_no_ordered_prefetch(self): + self._asr_dataset_helper(all_in_memory=False, ordered_prefetch=False) - def test_speech_dataset_cached_with_ordered_prefetch(self): - self._speech_dataset_helper(all_in_memory=False, ordered_prefetch=True) + def test_asr_dataset_cached_with_ordered_prefetch(self): + self._asr_dataset_helper(all_in_memory=False, ordered_prefetch=True) - def test_speech_dataset_all_in_memory(self): - self._speech_dataset_helper(all_in_memory=True) + def test_asr_dataset_all_in_memory(self): + self._asr_dataset_helper(all_in_memory=True) - def test_speech_dataset_has_utt2num_frames(self): - self._speech_dataset_helper(has_utt2num_frames=True) + def test_asr_dataset_has_utt2num_frames(self): + self._asr_dataset_helper(has_utt2num_frames=True) def assertTensorEqual(self, t1, t2): self.assertEqual(t1.size(), t2.size(), "size mismatch") From c05a3360d4667dfecf29aeae6279801a13957720 Mon Sep 17 00:00:00 2001 From: freewym Date: Fri, 20 Mar 2020 00:29:42 -0400 Subject: [PATCH 075/119] fix specaug indexing --- espresso/tools/specaug_interpolate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 
deletions(-) diff --git a/espresso/tools/specaug_interpolate.py b/espresso/tools/specaug_interpolate.py index 6a7ed33ec..2caf62ee5 100644 --- a/espresso/tools/specaug_interpolate.py +++ b/espresso/tools/specaug_interpolate.py @@ -95,7 +95,7 @@ def freq_mask(spec, F=30, num_masks=1, pad_value=0.): for i in range(num_masks): f = np.random.randint(0, F + 1) - f_zero = np.random.randint(0, num_mel_channels - f) + f_zero = np.random.randint(0, num_mel_channels - f + 1) if f == 0: return cloned @@ -124,7 +124,7 @@ def time_mask(spec, T=40, num_masks=1, p=0.2, pad_value=0.): for i in range(num_masks): t = np.random.randint(0, T + 1) - t_zero = np.random.randint(0, len_spectro - t) + t_zero = np.random.randint(0, len_spectro - t + 1) if t == 0: return cloned From 8772f69b4f1b723908bbe5a0840e304bc56da990 Mon Sep 17 00:00:00 2001 From: freewym Date: Tue, 24 Mar 2020 16:01:19 -0400 Subject: [PATCH 076/119] code adaptation/changes according to the commits on Mar 24-Apr 3, 2020; use data.encoders.{bpe,tokenizer} for wordpiece decode --- espresso/criterions/cross_entropy_v2.py | 12 +-- .../label_smoothed_cross_entropy_v2.py | 12 +-- espresso/data/asr_dictionary.py | 80 +++++++--------- espresso/data/encoders/__init__.py | 15 +++ espresso/data/encoders/characters_asr.py | 37 ++++++++ espresso/speech_recognize.py | 38 +++++--- espresso/speech_train.py | 93 +++++++++++-------- espresso/tasks/speech_recognition.py | 8 +- espresso/tools/wer.py | 10 +- examples/asr_librispeech/run.sh | 6 +- examples/asr_swbd/run.sh | 6 +- examples/asr_wsj/run.sh | 6 +- tests/espresso/test_asr_dataset.py | 26 +++--- tests/espresso/test_speech_utils.py | 74 ++++++++------- 14 files changed, 240 insertions(+), 183 deletions(-) create mode 100644 espresso/data/encoders/__init__.py create mode 100644 espresso/data/encoders/characters_asr.py diff --git a/espresso/criterions/cross_entropy_v2.py b/espresso/criterions/cross_entropy_v2.py index 45254fd8b..1888d69b9 100644 --- a/espresso/criterions/cross_entropy_v2.py +++ b/espresso/criterions/cross_entropy_v2.py @@ -20,12 +20,11 @@ @register_criterion("cross_entropy_v2") class CrossEntropyV2Criterion(CrossEntropyCriterion): - def __init__(self, task, sentence_avg, print_interval, remove_bpe): + def __init__(self, task, sentence_avg, print_interval): super().__init__(task, sentence_avg) self.dictionary = task.target_dictionary self.print_interval = print_interval - self.remove_bpe = remove_bpe self.epoch = 1 @staticmethod @@ -69,13 +68,8 @@ def forward(self, model, sample, reduce=True): i = np.random.randint(0, len(sample["id"])) ref_tokens = sample["target_raw_text"][i] length = utils.strip_pad(target.data[i], self.padding_idx).size(0) - ref_one = self.dictionary.tokens_to_sentence( - ref_tokens, use_unk_sym=False, bpe_symbol=self.remove_bpe, - ) - pred_one = self.dictionary.tokens_to_sentence( - self.dictionary.string(pred.data[i][:length]), use_unk_sym=True, - bpe_symbol=self.remove_bpe, - ) + ref_one = self.dictionary.wordpiece_decode(ref_tokens) + pred_one = self.dictionary.wordpiece_decode(self.dictionary.string(pred.data[i][:length])) logger.info("sample REF: " + ref_one) logger.info("sample PRD: " + pred_one) diff --git a/espresso/criterions/label_smoothed_cross_entropy_v2.py b/espresso/criterions/label_smoothed_cross_entropy_v2.py index 53e7f3f9d..383a4a1f7 100644 --- a/espresso/criterions/label_smoothed_cross_entropy_v2.py +++ b/espresso/criterions/label_smoothed_cross_entropy_v2.py @@ -81,14 +81,13 @@ class 
LabelSmoothedCrossEntropyV2Criterion(LabelSmoothedCrossEntropyCriterion): def __init__( self, task, sentence_avg, label_smoothing, smoothing_type, print_interval, - remove_bpe, unigram_pseudo_count, + unigram_pseudo_count, ): super().__init__(task, sentence_avg, label_smoothing) self.dictionary = task.target_dictionary self.smoothing_type = smoothing_type self.print_interval = print_interval - self.remove_bpe = remove_bpe self.epoch = 1 self.unigram_tensor = None if smoothing_type == "unigram": @@ -147,13 +146,8 @@ def forward(self, model, sample, reduce=True): i = np.random.randint(0, len(sample["id"])) ref_tokens = sample["target_raw_text"][i] length = utils.strip_pad(target.data[i], self.padding_idx).size(0) - ref_one = self.dictionary.tokens_to_sentence( - ref_tokens, use_unk_sym=False, bpe_symbol=self.remove_bpe, - ) - pred_one = self.dictionary.tokens_to_sentence( - self.dictionary.string(pred.data[i][:length]), use_unk_sym=True, - bpe_symbol=self.remove_bpe, - ) + ref_one = self.dictionary.wordpiece_decode(ref_tokens) + pred_one = self.dictionary.wordpiece_decode(self.dictionary.string(pred.data[i][:length])) logger.info("sample REF: " + ref_one) logger.info("sample PRD: " + pred_one) diff --git a/espresso/data/asr_dictionary.py b/espresso/data/asr_dictionary.py index 570c94add..8dc4610b1 100644 --- a/espresso/data/asr_dictionary.py +++ b/espresso/data/asr_dictionary.py @@ -5,9 +5,11 @@ import torch -from fairseq.data import Dictionary, data_utils +from fairseq.data import Dictionary, encoders from fairseq.file_io import PathManager -from fairseq.tokenizer import tokenize_line + +# will automatically load modules defined from there +from espresso.data import encoders as encoders_espresso class AsrDictionary(Dictionary): @@ -35,36 +37,8 @@ def __init__( self.add_symbol(s, n=0) self.nspecial = len(self.symbols) self.non_lang_syms = None - - def string(self, tensor, bpe_symbol=None, escape_unk=False): - """Helper for converting a tensor of token indices to a string. - - Can optionally remove BPE symbols or escape words. - - We overwrite this since we would like to also ignore . 
- """ - if torch.is_tensor(tensor) and tensor.dim() == 2: - return "\n".join(self.string(t, bpe_symbol, escape_unk) for t in tensor) - - def token_string(i): - if i == self.unk(): - return self.unk_string(escape_unk) - else: - return self[i] - - if hasattr(self, "bos_index"): - sent = " ".join( - token_string(i) - for i in tensor - if (i != self.eos()) and (i != self.bos()) and (i != self.pad()) - ) - else: - sent = " ".join( - token_string(i) - for i in tensor - if (i != self.eos()) and (i != self.pad()) - ) - return data_utils.process_bpe_symbol(sent, bpe_symbol) + self.tokenizer = None + self.bpe = None def bos(self): """Disallow beginning-of-sentence symbol""" @@ -119,20 +93,28 @@ def dummy_sentence(self, length): t[-1] = self.eos() return t - def tokens_to_sentence( - self, line, line_tokenizer=tokenize_line, use_unk_sym=True, bpe_symbol=None, - ): - if bpe_symbol is not None: - return data_utils.process_bpe_symbol(line, bpe_symbol) - # use_unk_sym=False when we want to restore original transcripts from - # token sequences, e.g., obtain reference to compute WER - tokens = line_tokenizer(line) - sent = "" - for token in tokens: - if token == self.space_word: - sent += " " - elif use_unk_sym and self.index(token) == self.unk_index: - sent += self.unk_word - elif token != self.pad_word and token != self.eos_word: - sent += token - return sent.strip() + def build_tokenizer(self, args): + self.tokenizer = encoders.build_tokenizer(args) + + def build_bpe(self, args): + if args.bpe == "characters_asr": + self.bpe = encoders.build_bpe( + args, space_symbol=self.space_word, ends_with_space=True, + non_lang_syms=self.non_lang_syms, + ) + else: + self.bpe = encoders.build_bpe(args) + + def wordpiece_encode(self, x): + if self.tokenizer is not None: + x = self.tokenizer.encode(x) + if self.bpe is not None: + x = self.bpe.encode(x) + return x + + def wordpiece_decode(self, x): + if self.bpe is not None: + x = self.bpe.decode(x) + if self.tokenizer is not None: + x = self.tokenizer.decode(x) + return x diff --git a/espresso/data/encoders/__init__.py b/espresso/data/encoders/__init__.py new file mode 100644 index 000000000..d8c6cd0fe --- /dev/null +++ b/espresso/data/encoders/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import importlib +import os + + +# automatically import any Python files in the encoders/ directory +for file in os.listdir(os.path.dirname(__file__)): + if not file.startswith("_") and not file.startswith(".") and file.endswith(".py"): + module = file[:file.find(".py")] + importlib.import_module("espresso.data.encoders." + module) diff --git a/espresso/data/encoders/characters_asr.py b/espresso/data/encoders/characters_asr.py new file mode 100644 index 000000000..ef424150f --- /dev/null +++ b/espresso/data/encoders/characters_asr.py @@ -0,0 +1,37 @@ +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ + +from typing import List, Optional + +from fairseq.data.encoders import register_bpe + +from espresso.tools.utils import tokenize + + +@register_bpe('characters_asr') +class CharactersAsr(object): + + @staticmethod + def add_args(parser): + pass + + def __init__( + self, args, space_symbol="", ends_with_space=True, + non_lang_syms: Optional[List[str]] = None, + ): + self.space_symbol = space_symbol + self.ends_with_space = ends_with_space + self.non_lang_syms = non_lang_syms + + def encode(self, x: str) -> str: + y = tokenize(x, space=self.space_symbol, non_lang_syms=self.non_lang_syms) + if self.ends_with_space: + return y + " " + self.space_symbol + else: + return y + + def decode(self, x: str) -> str: + return x.replace(" ", "").replace(self.space_symbol, " ").strip() diff --git a/espresso/speech_recognize.py b/espresso/speech_recognize.py index c44117f6b..f7ffd1d4d 100755 --- a/espresso/speech_recognize.py +++ b/espresso/speech_recognize.py @@ -16,6 +16,7 @@ import torch from fairseq import checkpoint_utils, options, tasks, utils +from fairseq.data import encoders from fairseq.logging import progress_bar from fairseq.logging.meters import StopwatchMeter, TimeMeter from fairseq.models import FairseqLanguageModel @@ -139,7 +140,18 @@ def _main(args, output_file): 'The option match_source_len is not applicable to speech recognition. Ignoring it.' ) gen_timer = StopwatchMeter() - generator = task.build_generator(args) + generator = task.build_generator(models, args) + + # Handle tokenization and BPE + tokenizer = encoders.build_tokenizer(args) + bpe = encoders.build_bpe(args) + + def decode_fn(x): + if bpe is not None: + x = bpe.decode(x) + if tokenizer is not None: + x = tokenizer.decode(x) + return x # Generate and compute WER scorer = wer.Scorer(dictionary, wer_output_filter=args.wer_output_filter) @@ -177,20 +189,20 @@ def _main(args, output_file): if has_target: target_str = sample['target_raw_text'][i] if not args.quiet: - target_sent = dictionary.tokens_to_sentence( - target_str, use_unk_sym=False, bpe_symbol=args.remove_bpe, - ) - print('T-{}\t{}'.format(utt_id, target_sent), file=output_file) + detok_target_str = decode_fn(target_str) + print('T-{}\t{}'.format(utt_id, detok_target_str), file=output_file) # Process top predictions for j, hypo in enumerate(hypos[i][:args.nbest]): - hypo_str = dictionary.string(hypo['tokens'].int().cpu()) # not removing bpe at this point - if not args.quiet or i == 0: - hypo_sent = dictionary.tokens_to_sentence(hypo_str, bpe_symbol=args.remove_bpe) - + hypo_str = dictionary.string( + hypo['tokens'].int().cpu(), + bpe_symbol=None, + extra_symbols_to_ignore={dictionary.pad()}, + ) # not removing bpe at this point + detok_hypo_str = decode_fn(hypo_str) if not args.quiet: score = hypo['score'] / math.log(2) # convert to base 2 - print('H-{}\t{}\t{}'.format(utt_id, hypo_sent, score), file=output_file) + print('H-{}\t{}\t{}'.format(utt_id, detok_hypo_str, score), file=output_file) # Score and obtain attention only the top hypothesis if j == 0: @@ -200,10 +212,10 @@ def _main(args, output_file): if args.print_alignment and attention is not None: save_dir = os.path.join(args.results_path, 'attn_plots') os.makedirs(save_dir, exist_ok=True) - plot_attention(attention, hypo_sent, utt_id, save_dir) - scorer.add_prediction(utt_id, hypo_str, bpe_symbol=args.remove_bpe) + plot_attention(attention, detok_hypo_str, utt_id, save_dir) + scorer.add_prediction(utt_id, hypo_str) if has_target: - scorer.add_evaluation(utt_id, target_str, hypo_str, 
bpe_symbol=args.remove_bpe) + scorer.add_evaluation(utt_id, target_str, hypo_str) wps_meter.update(num_generated_tokens) progress.log({'wps': round(wps_meter.avg)}) diff --git a/espresso/speech_train.py b/espresso/speech_train.py index 79d04f77a..aadd3a362 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -21,6 +21,7 @@ from fairseq.data import iterators from fairseq.logging import meters, metrics, progress_bar from fairseq.trainer import Trainer +from fairseq.model_parallel.megatron_trainer import MegatronTrainer logging.basicConfig( @@ -37,6 +38,7 @@ def main(args, init_distributed=False): assert args.max_tokens is not None or args.max_sentences is not None, \ 'Must specify batch size either with --max-tokens or --max-sentences' + metrics.reset() # Initialize CUDA and distributed training if torch.cuda.is_available() and not args.cpu: @@ -70,7 +72,11 @@ def main(args, init_distributed=False): )) # Build trainer - trainer = Trainer(args, task, model, criterion) + if args.model_parallel_size == 1: + trainer = Trainer(args, task, model, criterion) + else: + trainer = MegatronTrainer(args, task, model, criterion) + logger.info('training on {} GPUs'.format(args.distributed_world_size)) logger.info('max input frames per GPU = {} and max sentences per GPU = {}'.format( args.max_tokens, @@ -87,32 +93,18 @@ def main(args, init_distributed=False): lr = trainer.get_lr() train_meter = meters.StopwatchMeter() train_meter.start() - valid_subsets = args.valid_subset.split(',') while ( lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch - and trainer.get_num_updates() < max_update ): # train for one epoch - train(args, trainer, task, epoch_itr) - - if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0: - valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets) - else: - valid_losses = [None] + valid_losses = train(args, trainer, task, epoch_itr, max_update) + if should_stop_early(args, valid_losses[0]) or trainer.get_num_updates() >= max_update: + break # only use first validation loss to update the learning rate lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0]) - # save checkpoint - if epoch_itr.epoch % args.save_interval == 0: - checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) - - # early stop - if should_stop_early(args, valid_losses[0]): - logger.info('early stop since valid performance hasn\'t improved for last {} runs'.format(args.patience)) - break - epoch_itr = trainer.get_train_iterator( epoch_itr.next_epoch_idx, # sharded data: get train iterator for next epoch @@ -139,12 +131,16 @@ def is_better(a, b): return False else: should_stop_early.num_runs += 1 - return should_stop_early.num_runs >= args.patience + if should_stop_early.num_runs >= args.patience: + logger.info('early stop since valid performance hasn\'t improved for last {} runs'.format(args.patience)) + return True + else: + return False @metrics.aggregate('train') -def train(args, trainer, task, epoch_itr): - """Train the model for one epoch.""" +def train(args, trainer, task, epoch_itr, max_update=math.inf): + """Train the model for one epoch and return validation losses.""" # Initialize data iterator itr = epoch_itr.next_epoch_itr( fix_batches_to_gpus=args.fix_batches_to_gpus, @@ -170,10 +166,10 @@ def train(args, trainer, task, epoch_itr): # task specific setup per epoch task.begin_epoch(epoch_itr.epoch, trainer.get_model()) - valid_subsets = args.valid_subset.split(',') - max_update = args.max_update or math.inf if 
hasattr(trainer.criterion, 'set_epoch'): trainer.criterion.set_epoch(epoch_itr.epoch) + + valid_subsets = args.valid_subset.split(',') for samples in progress: with metrics.aggregate('train_inner'): log_output = trainer.train_step(samples) @@ -190,16 +186,8 @@ def train(args, trainer, task, epoch_itr): # the end-of-epoch stats will still be preserved metrics.reset_meters('train_inner') - if ( - not args.disable_validation - and args.save_interval_updates > 0 - and num_updates % args.save_interval_updates == 0 - and num_updates > 0 - ): - valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets) - checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) - - if num_updates >= max_update: + valid_losses = validate_and_save(args, trainer, task, epoch_itr, valid_subsets) + if should_stop_early(args, valid_losses[0]) or num_updates >= max_update: break # log end-of-epoch stats @@ -208,6 +196,41 @@ def train(args, trainer, task, epoch_itr): # reset epoch-level meters metrics.reset_meters('train') + return valid_losses + + +def validate_and_save(args, trainer, task, epoch_itr, valid_subsets): + num_updates = trainer.get_num_updates() + do_save = ( + ( + args.save_interval_updates > 0 + and num_updates > 0 + and num_updates % args.save_interval_updates == 0 + ) + or ( + epoch_itr.end_of_epoch() + and epoch_itr.epoch % args.save_interval == 0 + ) + ) + do_validate = ( + ( + do_save # saving requires validation + or ( + epoch_itr.end_of_epoch() + and epoch_itr.epoch % args.validate_interval == 0 + ) + ) + and not args.disable_validation + ) + + # Validate + valid_losses = [None] + if do_validate: + valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets) + # Save + if do_save: + checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) + return valid_losses def get_training_stats(stats): @@ -298,10 +321,6 @@ def print_options_meaning_changes(args): def cli_main(modify_parser=None): parser = options.get_training_parser() - parser.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None, - help='remove BPE tokens before scoring ' - '(can be set to sentencepiece). 
Being used for monitoring ' - 'and validation') args = options.parse_args_and_arch(parser, modify_parser=modify_parser) print_options_meaning_changes(args) diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index ab8b2eca2..8f6cb4a19 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -192,6 +192,8 @@ def build_dictionary( def __init__(self, args, tgt_dict, word_dict=None): super().__init__(args) self.tgt_dict = tgt_dict + self.tgt_dict.build_tokenizer(args) + self.tgt_dict.build_bpe(args) self.word_dict = word_dict self.feat_in_channels = args.feat_in_channels self.specaugment_config = args.specaugment_config @@ -253,7 +255,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): unk_count += (tgt_dataset[i][0] == self.tgt_dict.unk()).int().sum().item() self.tgt_dict.count[self.tgt_dict.unk()] = unk_count - def build_generator(self, args): + def build_generator(self, models, args): if getattr(args, "score_reference", False): args.score_reference = False logger.warning( @@ -390,9 +392,7 @@ def _inference_with_wer(self, decoder, sample, model): utt_id = sample["utt_id"][i] ref_tokens = sample["target_raw_text"][i] pred_tokens = self.target_dictionary.string(pred.data[i]) - scorer.add_evaluation( - utt_id, ref_tokens, pred_tokens, bpe_symbol=self.args.remove_bpe, - ) + scorer.add_evaluation(utt_id, ref_tokens, pred_tokens) return ( scorer.tot_word_error(), scorer.tot_word_count(), scorer.tot_char_error(), scorer.tot_char_count(), diff --git a/espresso/tools/wer.py b/espresso/tools/wer.py index 353a80cb3..07c802883 100644 --- a/espresso/tools/wer.py +++ b/espresso/tools/wer.py @@ -46,7 +46,7 @@ def parse_wer_output_filter(self, wer_output_filter): else: logger.warning('Unsupported pattern: "{}". 
Ignoring it'.format(line)) - def add_prediction(self, utt_id, pred, bpe_symbol=None): + def add_prediction(self, utt_id, pred): if not isinstance(utt_id, str): raise TypeError('utt_id must be a string(got {})'.format(type(utt_id))) if not isinstance(pred, str): @@ -56,12 +56,12 @@ def add_prediction(self, utt_id, pred, bpe_symbol=None): 'Duplicated utterance id detected: {}'.format(utt_id) self.char_results[utt_id] = pred + '\n' - pred_words = self.dictionary.tokens_to_sentence(pred, bpe_symbol=bpe_symbol) + pred_words = self.dictionary.wordpiece_decode(pred) assert utt_id not in self.results, \ 'Duplicated utterance id detected: {}'.format(utt_id) self.results[utt_id] = pred_words + '\n' - def add_evaluation(self, utt_id, ref, pred, bpe_symbol=None): + def add_evaluation(self, utt_id, ref, pred): if not isinstance(utt_id, str): raise TypeError('utt_id must be a string(got {})'.format(type(utt_id))) if not isinstance(ref, str): @@ -84,8 +84,8 @@ def add_evaluation(self, utt_id, ref, pred, bpe_symbol=None): self.char_counter += counter # word level counts - ref_words = self.dictionary.tokens_to_sentence(ref, use_unk_sym=False, bpe_symbol=bpe_symbol) - pred_words = self.dictionary.tokens_to_sentence(pred, bpe_symbol=bpe_symbol) + ref_words = self.dictionary.wordpiece_decode(ref) + pred_words = self.dictionary.wordpiece_decode(pred) # filter words according to self.word_filters (support re.sub only) for pattern, repl in self.word_filters: diff --git a/examples/asr_librispeech/run.sh b/examples/asr_librispeech/run.sh index 944ca2519..f8051492e 100755 --- a/examples/asr_librispeech/run.sh +++ b/examples/asr_librispeech/run.sh @@ -239,7 +239,7 @@ if [ ${stage} -le 8 ]; then --arch speech_conv_lstm_librispeech --criterion label_smoothed_cross_entropy_v2 \ --label-smoothing 0.1 --smoothing-type uniform \ --scheduled-sampling-probs 1.0 --start-scheduled-sampling-epoch 1 \ - --dict $dict --remove-bpe sentencepiece \ + --dict $dict --bpe sentencepiece --sentencepiece-vocab ${sentencepiece_model}.model \ --max-source-positions 9999 --max-target-positions 999 \ $opts --specaugment-config "$specaug_config" 2>&1 | tee $log_file fi @@ -262,8 +262,8 @@ if [ ${stage} -le 9 ]; then decode_dir=$dir/decode_$dataset${decode_affix:+_${decode_affix}} CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py data \ --task speech_recognition_espresso --user-dir espresso --max-tokens 15000 --max-sentences 24 \ - --num-shards 1 --shard-id 0 --dict $dict --remove-bpe sentencepiece --gen-subset $dataset \ - --max-source-positions 9999 --max-target-positions 999 \ + --num-shards 1 --shard-id 0 --dict $dict --bpe sentencepiece --sentencepiece-vocab ${sentencepiece_model}.model \ + --gen-subset $dataset --max-source-positions 9999 --max-target-positions 999 \ --path $path --beam 60 --max-len-a 0.08 --max-len-b 0 --lenpen 1.0 \ --results-path $decode_dir $opts diff --git a/examples/asr_swbd/run.sh b/examples/asr_swbd/run.sh index f3de6287f..718fe10de 100755 --- a/examples/asr_swbd/run.sh +++ b/examples/asr_swbd/run.sh @@ -280,7 +280,7 @@ if [ $stage -le 7 ]; then --arch speech_conv_lstm_swbd --criterion label_smoothed_cross_entropy_v2 \ --label-smoothing 0.1 --smoothing-type uniform \ --scheduled-sampling-probs 0.9,0.8,0.7,0.6 --start-scheduled-sampling-epoch 6 \ - --dict $dict --remove-bpe sentencepiece --non-lang-syms $nlsyms \ + --dict $dict --bpe sentencepiece --sentencepiece-vocab ${sentencepiece_model}.model --non-lang-syms $nlsyms \ --max-source-positions 9999 
--max-target-positions 999 \ $opts --specaugment-config "$specaug_config" 2>&1 | tee $log_file fi @@ -301,8 +301,8 @@ if [ $stage -le 8 ]; then decode_dir=$dir/decode_${dataset}${decode_affix:+_${decode_affix}} CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py data \ --task speech_recognition_espresso --user-dir espresso --max-tokens 24000 --max-sentences 48 \ - --num-shards 1 --shard-id 0 --dict $dict --remove-bpe sentencepiece --non-lang-syms $nlsyms --gen-subset $dataset \ - --max-source-positions 9999 --max-target-positions 999 \ + --num-shards 1 --shard-id 0 --dict $dict --bpe sentencepiece --sentencepiece-vocab ${sentencepiece_model}.model \ + --non-lang-syms $nlsyms --gen-subset $dataset --max-source-positions 9999 --max-target-positions 999 \ --path $path --beam 35 --max-len-a 0.1 --max-len-b 0 --lenpen 1.0 \ --results-path $decode_dir $opts diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index 4358f2619..60c775339 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -281,7 +281,7 @@ if [ ${stage} -le 9 ]; then --arch speech_conv_lstm_wsj --criterion label_smoothed_cross_entropy_v2 \ --label-smoothing 0.05 --smoothing-type temporal \ --scheduled-sampling-probs 0.5 --start-scheduled-sampling-epoch 6 \ - --dict $dict --non-lang-syms $nlsyms \ + --dict $dict --bpe characters_asr --non-lang-syms $nlsyms \ --max-source-positions 9999 --max-target-positions 999 $opts 2>&1 | tee $log_file fi @@ -306,8 +306,8 @@ if [ ${stage} -le 10 ]; then decode_dir=$dir/decode_$dataset${decode_affix:+_${decode_affix}} CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py data \ --task speech_recognition_espresso --user-dir espresso --max-tokens 20000 --max-sentences 32 \ - --num-shards 1 --shard-id 0 --dict $dict --non-lang-syms $nlsyms --gen-subset $dataset \ - --max-source-positions 9999 --max-target-positions 999 \ + --num-shards 1 --shard-id 0 --dict $dict --bpe characters_asr --non-lang-syms $nlsyms \ + --gen-subset $dataset --max-source-positions 9999 --max-target-positions 999 \ --path $path --beam 50 --max-len-a 0.2 --max-len-b 0 --lenpen 1.0 \ --results-path $decode_dir $opts --print-alignment diff --git a/tests/espresso/test_asr_dataset.py b/tests/espresso/test_asr_dataset.py index d894e7d51..2101253df 100644 --- a/tests/espresso/test_asr_dataset.py +++ b/tests/espresso/test_asr_dataset.py @@ -21,7 +21,7 @@ try: import kaldi_io except ImportError: - raise ImportError('Please install kaldi_io with: pip install kaldi_io') + raise ImportError("Please install kaldi_io with: pip install kaldi_io") class TestAsrDataset(unittest.TestCase): @@ -33,9 +33,9 @@ def make_dictionary(): alphabet = string.ascii_lowercase for token in alphabet: d.add_symbol(token) - d.add_symbol('') + d.add_symbol("") d.finalize(padding_factor=1) # don't add extra padding symbols - d.space_index = d.indices.get('', -1) + d.space_index = d.indices.get("", -1) return d @staticmethod @@ -45,14 +45,14 @@ def generate_feats(test_dir, num=10, seed=0): np.random.seed(seed) utt_ids, rxfiles, utt2num_frames = [], [], [] for i in range(num): - utt_id = 'utt_id_' + str(i) - ark_file = os.path.join(test_dir, 'mat_' + str(i) + '.ark') + utt_id = "utt_id_" + str(i) + ark_file = os.path.join(test_dir, "mat_" + str(i) + ".ark") length = np.random.randint(200, 800) m = np.random.uniform(-10.0, 10.0, (length, 40)) expected_feats[utt_id] = m kaldi_io.write_mat(ark_file, m) utt_ids.append(utt_id) - rxfiles.append(ark_file + ':0') + 
rxfiles.append(ark_file + ":0") utt2num_frames.append(length) return expected_feats, utt_ids, rxfiles, utt2num_frames @@ -62,13 +62,13 @@ def generate_text(test_dir, num=10, seed=0): order from those in feats.scp.""" expected_text = {} alphabet = string.ascii_lowercase - space = '' + space = "" vocab = list(alphabet) vocab.append(space) np.random.seed(seed) utt_ids, token_text = [], [] for i in np.random.permutation(range(num)): - utt_id = 'utt_id_' + str(i) + utt_id = "utt_id_" + str(i) length = np.random.randint(10, 100) tokens = [ vocab[np.random.randint(0, len(vocab))] for _ in range(length) @@ -79,11 +79,11 @@ def generate_text(test_dir, num=10, seed=0): tokens[-1] = vocab[np.random.randint(0, len(vocab) - 1)] expected_text[utt_id] = tokens utt_ids.append(utt_id) - token_text.append(' '.join(tokens)) + token_text.append(" ".join(tokens)) return expected_text, utt_ids, token_text def setUp(self): - self.test_dir = './temp' + self.test_dir = "./temp" os.makedirs(self.test_dir, exist_ok=True) self.num_audios = 150 self.num_transripts = 100 @@ -155,8 +155,10 @@ def _asr_dataset_helper( self.assertEqual(bsz, len(batch_sampler[i])) src_frames = batch["net_input"]["src_tokens"] src_lengths = batch["net_input"]["src_lengths"] - tgt_tokens = self.dictionary.string(batch["target"]).split('\n') - tgt_tokens = [line.split(' ') for line in tgt_tokens] + tgt_tokens = self.dictionary.string( + batch["target"], extra_symbols_to_ignore={self.dictionary.pad()} + ).split("\n") + tgt_tokens = [line.split(" ") for line in tgt_tokens] self.assertEqual(bsz, src_frames.size(0)) self.assertEqual(bsz, src_lengths.numel()) self.assertEqual(bsz, len(tgt_tokens)) diff --git a/tests/espresso/test_speech_utils.py b/tests/espresso/test_speech_utils.py index c3eb10ce9..96173be1e 100644 --- a/tests/espresso/test_speech_utils.py +++ b/tests/espresso/test_speech_utils.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from argparse import Namespace import logging import unittest import string @@ -26,13 +27,16 @@ def make_dictionary(vocab, non_lang_syms=[]): """construct dictionary.""" assert isinstance(vocab, list) and isinstance(non_lang_syms, list) d = AsrDictionary() + d.non_lang_syms = non_lang_syms + args = Namespace(bpe="characters_asr") + d.build_bpe(args) for token in vocab: d.add_symbol(token) - d.add_symbol('') + d.add_symbol("") for token in non_lang_syms: d.add_symbol(token) d.finalize(padding_factor=1) # don't add extra padding symbols - d.space_index = d.indices.get('', -1) + d.space_index = d.indices.get("", -1) return d @staticmethod @@ -42,27 +46,27 @@ def generate_text(vocab, oovs=[], non_lang_syms=[], seed=0): isinstance(non_lang_syms, list) np.random.seed(seed) sent_len = np.random.randint(2, 30) - sent = '' + sent = "" for _ in range(sent_len): if len(non_lang_syms) > 0 and np.random.randint(0, 20) == 0: word = non_lang_syms[np.random.randint(0, len(non_lang_syms))] else: - word = '' + word = "" word_len = np.random.randint(2, 11) for _ in range(word_len): if len(oovs) > 0 and np.random.randint(0, 20) == 0: word += oovs[np.random.randint(0, len(oovs))] else: word += vocab[np.random.randint(0, len(vocab))] - sent += word + ' ' + sent += word + " " - sent = ' '.join(sent.strip().split(' ')) + sent = " ".join(sent.strip().split(" ")) return sent def setUp(self): self.vocab = list(string.ascii_lowercase) self.oovs = list(string.ascii_uppercase) - self.non_lang_syms = ['', '', ''] + self.non_lang_syms = ["", "", ""] self.num_sentences = 100 self.dictionary = self.make_dictionary( self.vocab, @@ -74,39 +78,38 @@ def setUp(self): def test_speech_tokenizer(self): for i, sent in enumerate(self.text): - logger.info('test sentence {}:'.format(i)) + logger.info("test sentence {}:".format(i)) logger.info(sent) - tokens = utils.tokenize( - sent, space=self.dictionary.space_word, - non_lang_syms=self.non_lang_syms, - ) + tokens = self.dictionary.wordpiece_encode(sent) - # test :func:`~speech_tools.utils.tokenize` with + # test :func:`~AsrDictionary.wordpiece_encode` with # :func:`~AsrDictionary.encode_line` tensor = self.dictionary.encode_line( tokens, add_if_not_exist=False, append_eos=True, ) - reconstructed_tokens = self.dictionary.string(tensor) - expected_tokens = ' '.join( + reconstructed_tokens = self.dictionary.string( + tensor, extra_symbols_to_ignore={self.dictionary.pad()} + ) + expected_tokens = " ".join( [token if self.dictionary.index(token) != self.dictionary.unk() else - self.dictionary.unk_word for token in tokens.split(' ')] + self.dictionary.unk_word for token in tokens.split(" ")] ) self.assertEqual(reconstructed_tokens, expected_tokens) - # test :func:`~speech_tools.utils.tokenize` with - # :func:`~AsrDictionary.tokens_to_sentence` - reconstructed_sent = self.dictionary.tokens_to_sentence(tokens) + # test :func:`~AsrDictionary.wordpiece_encode` with + # :func:`~AsrDictionary.wordpiece_decode` + reconstructed_sent = self.dictionary.wordpiece_decode(reconstructed_tokens) expected_sent = [] - words = sent.split(' ') + words = sent.split(" ") for w in words: if w not in self.non_lang_syms: - new_word = ''.join( + new_word = "".join( [self.dictionary.unk_word if c in self.oovs else c for c in w] ) expected_sent.append(new_word) else: expected_sent.append(w) - expected_sent = ' '.join(expected_sent) + expected_sent = " ".join(expected_sent) self.assertEqual(reconstructed_sent, expected_sent) def test_collate_frames(self): @@ -179,44 +182,43 @@ def test_edit_distance(self): dist, 
steps, counter = utils.edit_distance(ref, hyp) self.assertEqual( counter, - Counter({'words': 0, 'corr': 0, 'sub': 0, 'ins': 0, 'del': 0}), + Counter({"words": 0, "corr": 0, "sub": 0, "ins": 0, "del": 0}), ) self.assertEqual(steps, []) - ref, hyp = ['a', 'b', 'c'], [] + ref, hyp = ["a", "b", "c"], [] dist, steps, counter = utils.edit_distance(ref, hyp) self.assertEqual( counter, - Counter({'words': 3, 'corr': 0, 'sub': 0, 'ins': 0, 'del': 3}), + Counter({"words": 3, "corr": 0, "sub": 0, "ins": 0, "del": 3}), ) - self.assertEqual(steps, ['del', 'del', 'del']) + self.assertEqual(steps, ["del", "del", "del"]) - ref, hyp = ['a', 'b', 'c'], ['a', 'b', 'c'] + ref, hyp = ["a", "b", "c"], ["a", "b", "c"] dist, steps, counter = utils.edit_distance(ref, hyp) self.assertEqual( counter, - Counter({'words': 3, 'corr': 3, 'sub': 0, 'ins': 0, 'del': 0}), + Counter({"words": 3, "corr": 3, "sub": 0, "ins": 0, "del": 0}), ) - self.assertEqual(steps, ['corr', 'corr', 'corr']) + self.assertEqual(steps, ["corr", "corr", "corr"]) - ref, hyp = ['a', 'b', 'c'], ['d', 'b', 'c', 'e', 'f'] + ref, hyp = ["a", "b", "c"], ["d", "b", "c", "e", "f"] dist, steps, counter = utils.edit_distance(ref, hyp) self.assertEqual( counter, - Counter({'words': 3, 'corr': 2, 'sub': 1, 'ins': 2, 'del': 0}), + Counter({"words": 3, "corr": 2, "sub": 1, "ins": 2, "del": 0}), ) - self.assertEqual(steps, ['sub', 'corr', 'corr', 'ins', 'ins']) + self.assertEqual(steps, ["sub", "corr", "corr", "ins", "ins"]) - ref, hyp = ['b', 'c', 'd', 'e', 'f', 'h'], \ - ['d', 'b', 'c', 'e', 'f', 'g'] + ref, hyp = ["b", "c", "d", "e", "f", "h"], ["d", "b", "c", "e", "f", "g"] dist, steps, counter = utils.edit_distance(ref, hyp) self.assertEqual( counter, - Counter({'words': 6, 'corr': 4, 'sub': 1, 'ins': 1, 'del': 1}), + Counter({"words": 6, "corr": 4, "sub": 1, "ins": 1, "del": 1}), ) self.assertEqual( steps, - ['ins', 'corr', 'corr', 'del', 'corr', 'corr', 'sub'], + ["ins", "corr", "corr", "del", "corr", "corr", "sub"], ) def assertTensorEqual(self, t1, t2): From e40b033f69efdf102669bc709ad7ed80662c93e2 Mon Sep 17 00:00:00 2001 From: freewym Date: Tue, 7 Apr 2020 18:20:41 -0400 Subject: [PATCH 077/119] code adaptation/changes according to the commits on Apr 7 --- espresso/speech_recognize.py | 4 +- espresso/tasks/speech_recognition.py | 17 ++-- espresso/tools/simple_greedy_decoder.py | 64 ++++++++----- fairseq/sequence_generator.py | 120 +++++++++++++++++++++++- 4 files changed, 165 insertions(+), 40 deletions(-) diff --git a/espresso/speech_recognize.py b/espresso/speech_recognize.py index f7ffd1d4d..ff8fbbcd8 100755 --- a/espresso/speech_recognize.py +++ b/espresso/speech_recognize.py @@ -168,9 +168,7 @@ def decode_fn(x): prefix_tokens = sample['target'][:, :args.prefix_size] gen_timer.start() - hypos = task.inference_step( - generator, models, sample, prefix_tokens, lm_weight=args.lm_weight, - ) + hypos = task.inference_step(generator, models, sample, prefix_tokens) num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos) gen_timer.stop(num_generated_tokens) diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index 8f6cb4a19..f60e94de4 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -118,7 +118,7 @@ def get_asr_dataset_from_json( ) -@register_task('speech_recognition_espresso') +@register_task("speech_recognition_espresso") class SpeechRecognitionEspressoTask(FairseqTask): """ Transcribe from speech (source) to token text (target). 
@@ -315,6 +315,7 @@ def build_generator(self, models, args): search_strategy = search.BeamSearch(self.target_dictionary) return SequenceGenerator( + models, self.target_dictionary, beam_size=getattr(args, "beam", 5), max_len_a=getattr(args, "max_len_a", 0), @@ -327,6 +328,7 @@ def build_generator(self, models, args): match_source_len=getattr(args, "match_source_len", False), no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0), search_strategy=search_strategy, + lm_weight = getattr(args, "lm_weight", 0.0), eos_factor=getattr(args, "eos_factor", None), ) @@ -334,10 +336,13 @@ def build_dataset_for_inference(self, src_tokens, src_lengths): return AsrDataset(src_tokens, src_lengths) def build_model(self, args): + model = super().build_model(args) # build the greedy decoder for validation with WER from espresso.tools.simple_greedy_decoder import SimpleGreedyDecoder - self.decoder_for_validation = SimpleGreedyDecoder(self.target_dictionary, for_validation=True) - return super().build_model(args) + self.decoder_for_validation = SimpleGreedyDecoder( + [model], self.target_dictionary, for_validation=True, + ) + return model def valid_step(self, sample, model, criterion): loss, sample_size, logging_output = super().valid_step(sample, model, criterion) @@ -347,12 +352,6 @@ def valid_step(self, sample, model, criterion): ) = self._inference_with_wer(self.decoder_for_validation, sample, model) return loss, sample_size, logging_output - def inference_step(self, generator, models, sample, prefix_tokens=None, lm_weight=0.0): - with torch.no_grad(): - return generator.generate( - models, sample, prefix_tokens=prefix_tokens, lm_weight=lm_weight, - ) - def reduce_metrics(self, logging_outputs, criterion): super().reduce_metrics(logging_outputs, criterion) word_error = sum(log.get("word_error", 0) for log in logging_outputs) diff --git a/espresso/tools/simple_greedy_decoder.py b/espresso/tools/simple_greedy_decoder.py index 593b2041a..84a1f3403 100644 --- a/espresso/tools/simple_greedy_decoder.py +++ b/espresso/tools/simple_greedy_decoder.py @@ -3,21 +3,30 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from typing import Dict, Optional + import numpy as np import torch +import torch.nn as nn +from torch import Tensor -class SimpleGreedyDecoder(object): +class SimpleGreedyDecoder(nn.Module): def __init__( - self, dictionary, max_len_a=0, max_len_b=200, temperature=1., for_validation=True, + self, models, dictionary, max_len_a=0, max_len_b=200, retain_dropout=False, + temperature=1.0, for_validation=True, ): """Decode given speech audios with the simple greedy search. Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models, + currently support fairseq.models.TransformerModel for scripting dictionary (~fairseq.data.Dictionary): dictionary max_len_a/b (int, optional): generate sequences of maximum length ax + b, where x is the source length + retain_dropout (bool, optional): use dropout when generating + (default: False) temperature (float, optional): temperature, where values >1.0 produce more uniform samples and values <1.0 produce sharper samples (default: 1.0) @@ -26,19 +35,32 @@ def __init__( whether a tensor of lprobs is returned. 
If true, target should be not None """ + super().__init__() + from fairseq.sequence_generator import EnsembleModel + if isinstance(models, EnsembleModel): + self.model = models + else: + self.model = EnsembleModel(models) self.pad = dictionary.pad() self.unk = dictionary.unk() self.eos = dictionary.eos() self.vocab_size = len(dictionary) self.max_len_a = max_len_a self.max_len_b = max_len_b + self.retain_dropout = retain_dropout self.temperature = temperature - assert temperature > 0, '--temperature must be greater than 0' + assert temperature > 0, "--temperature must be greater than 0" + if not self.retain_dropout: + self.model.eval() self.for_validation = for_validation + def cuda(self): + self.model.cuda() + return self + @torch.no_grad() - def decode(self, models, sample, **kwargs): - """Generate a batch of translations. + def decode(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs): + """Generate a batch of translations. Match the api of other fairseq generators. Args: models (List[~fairseq.models.FairseqModel]): ensemble of models @@ -46,27 +68,23 @@ def decode(self, models, sample, **kwargs): bos_token (int, optional): beginning of sentence token (default: self.eos) """ - from fairseq.sequence_generator import EnsembleModel - model = EnsembleModel(models) - return self._decode(model, sample, **kwargs) + self.model.reset_incremental_state() + return self._decode(sample, **kwargs) @torch.no_grad() - def _decode(self, model, sample, bos_token=None, **kwargs): - model.eval() - - # model.forward normally channels prev_output_tokens into the decoder - # separately, but SimpleGreedyDecoder directly calls model.encoder - encoder_input = { - k: v for k, v in sample["net_input"].items() - if k != "prev_output_tokens" - } + def _decode(self, sample: Dict[str, Dict[str, Tensor]], bos_token: Optional[int] = None): + encoder_input: Dict[str, Tensor] = {} + for k, v in sample["net_input"].items(): + if k != "prev_output_tokens": + encoder_input[k] = v src_tokens = encoder_input["src_tokens"] input_size = src_tokens.size() - # batch dimension goes first followed by source lengths - bsz = input_size[0] - src_len = input_size[1] + bsz, src_len = input_size[0], input_size[1] - encoder_outs = model.forward_encoder(encoder_input) + encoder_outs = self.model.forward_encoder( + src_tokens=encoder_input["src_tokens"], + src_lengths=encoder_input["src_lengths"], + ) target = sample["target"] # target can only be None if not for validation assert target is not None or not self.for_validation @@ -79,7 +97,7 @@ def _decode(self, model, sample, bos_token=None, **kwargs): min( int(self.max_len_a * src_len + self.max_len_b), # exclude the EOS marker - model.max_decoder_positions() - 1, + self.model.max_decoder_positions() - 1, ) tokens = src_tokens.new(bsz, max_len + 2).long().fill_(self.pad) @@ -97,7 +115,7 @@ def _decode(self, model, sample, bos_token=None, **kwargs): if attn is not None: attn = attn[:, :, :step + 1] break - log_probs, avg_attn_scores = model.forward_decoder( + log_probs, avg_attn_scores = self.model.forward_decoder( tokens[:, :step + 1], encoder_outs, temperature=self.temperature, ) tokens[:, step + 1] = log_probs.argmax(-1) diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index bd46f9e5b..7a8aac47f 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -13,6 +13,8 @@ from fairseq.models import FairseqIncrementalDecoder from torch import Tensor +from espresso.models.external_language_model import RawOutExternalLanguageModelBase + 
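The import added above is used by the LMFusionModel wrapper introduced further down in this file to detect external language models whose decoder already emits log-probabilities. That wrapper performs shallow fusion: at each decoding step the end-to-end model's log-probabilities and the external LM's log-probabilities, scaled by lm_weight, are summed. A small numeric sketch of that combination; the shapes and the 0.5 weight are assumptions for illustration only:

    import torch

    lm_weight = 0.5  # assumed value; normally taken from --lm-weight
    # per-step log-probabilities over the vocabulary, shape (batch, vocab)
    e2e_lprobs = torch.log_softmax(torch.randn(2, 10), dim=-1)
    lm_lprobs = torch.log_softmax(torch.randn(2, 10), dim=-1)

    # shallow fusion: sum in log space with the LM term weighted, mirroring
    # torch.sum(torch.stack(log_probs, dim=0), dim=0) in LMFusionModel
    fused = e2e_lprobs + lm_weight * lm_lprobs
    next_tokens = fused.argmax(dim=-1)
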
class SequenceGenerator(nn.Module): def __init__( @@ -34,6 +36,7 @@ def __init__( symbols_to_strip_from_output=None, lm_model=None, lm_weight=1.0, + eos_factor=None, ): """Generates translations of a given source sentence. @@ -61,7 +64,7 @@ def __init__( if isinstance(models, EnsembleModel): self.model = models else: - self.model = EnsembleModel(models) + self.model = EnsembleModel(models) if lm_weight == 0.0 else LMFusionModel(models, lm_weight) self.tgt_dict = tgt_dict self.pad = tgt_dict.pad() self.unk = tgt_dict.unk() @@ -85,7 +88,9 @@ def __init__( self.temperature = temperature self.match_source_len = match_source_len self.no_repeat_ngram_size = no_repeat_ngram_size + self.eos_factor = eos_factor assert temperature > 0, "--temperature must be greater than 0" + assert eos_factor is None or eos_factor >= 1.0, "--eos-factor must be >= 1.0 if set" self.search = ( search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy @@ -193,10 +198,13 @@ def _generate( if "src_tokens" in net_input: src_tokens = net_input["src_tokens"] - # length of the source text being the character length except EndOfSentence and pad - src_lengths = ( - (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1) - ) + if src_tokens.dim() > 2: + src_lengths = encoder_input["src_lengths"] + else: + # length of the source text being the character length except EndOfSentence and pad + src_lengths = ( + (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1) + ) elif "source" in net_input: src_tokens = net_input["source"] src_lengths = ( @@ -331,6 +339,11 @@ def _generate( if step >= max_len: lprobs[:, : self.eos] = -math.inf lprobs[:, self.eos + 1 :] = -math.inf + elif self.eos_factor is not None: + # only consider EOS if its score is no less than a specified + # factor of the best candidate score + disallow_eos_mask = lprobs[:, self.eos] < self.eos_factor * lprobs.max(dim=1)[0] + lprobs[disallow_eos_mask, self.eos] = -math.inf # handle prefix tokens (possibly with different lengths) if ( @@ -1007,3 +1020,100 @@ def forward_align(self, src_tokens, src_lengths, prev_output_tokens): if len(self.models) > 1: avg_attn.div_(len(self.models)) return avg_attn + + +class LMFusionModel(EnsembleModel): + """A wrapper around an ensemble of an LM fused model.""" + + def __init__(self, models, lm_weight): + super().__init__(models) + self.lm_weight = lm_weight + assert self.models_size == 2, "Only support LM fusion with one E2E model" + assert self.has_encoder() + + @torch.jit.export + def forward_encoder(self, src_tokens, src_lengths): + return [ + model.encoder(src_tokens=src_tokens, src_lengths=src_lengths) if hasattr(model, "encoder") \ + else None for model in self.models + ] + + @torch.jit.export + def forward_decoder( + self, tokens, encoder_outs: List[EncoderOut], temperature: float = 1.0 + ): + log_probs = [] + avg_attn: Optional[Tensor] = None + attn_count = 0 + encoder_out: Optional[EncoderOut] = None + for i, model in enumerate(self.models): + encoder_out = encoder_outs[i] + # decode each model + if self.has_incremental_states(): + decoder_out = model.decoder.forward( + tokens, + encoder_out=encoder_out, + incremental_state=self.incremental_states[i], + ) + else: + decoder_out = model.decoder.forward(tokens, encoder_out=encoder_out) + + attn: Optional[Tensor] = None + decoder_len = len(decoder_out) + if decoder_len > 1 and decoder_out[1] is not None: + if isinstance(decoder_out[1], Tensor): + attn = decoder_out[1] + else: + attn_holder = decoder_out[1]["attn"] + if 
isinstance(attn_holder, Tensor): + attn = attn_holder + elif attn_holder is not None: + attn = attn_holder[0] + if attn is not None: + attn = attn[:, -1, :] + + decoder_out_tuple = ( + decoder_out[0][:, -1:, :].div_(temperature), + None if decoder_len <= 1 else decoder_out[1], + ) + + if isinstance(model, RawOutExternalLanguageModelBase): + probs = decoder_out_tuple[0] + else: + probs = model.get_normalized_probs( + decoder_out_tuple, log_probs=True, sample=None + ) + probs = probs[:, -1, :] + if i == 1 and self.lm_weight != 1.0: # assuming LM is the last model + probs.mul_(self.lm_weight) + + log_probs.append(probs) + if attn is not None: + if avg_attn is None: + avg_attn = attn + else: + avg_attn.add_(attn) + attn_count += 1 + avg_probs = torch.sum(torch.stack(log_probs, dim=0), dim=0) + if avg_attn is not None: + avg_attn.div_(attn_count) + return avg_probs, avg_attn + + @torch.jit.export + def reorder_encoder_out(self, encoder_outs: Optional[List[EncoderOut]], new_order): + """ + Reorder encoder output according to *new_order*. + + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + *encoder_out* rearranged according to *new_order* + """ + new_outs: List[EncoderOut] = [] + for i, model in enumerate(self.models): + new_outs.append( + model.encoder.reorder_encoder_out(encoder_outs[i], new_order) if hasattr(model, "encoder") else None + ) + return new_outs From b9c19c2e63ced0b55d34247c695f1da47809fb0e Mon Sep 17 00:00:00 2001 From: freewym Date: Mon, 13 Apr 2020 22:55:40 -0400 Subject: [PATCH 078/119] update the qsub script for gpu jobs; code adaptation/changes according to the commits on Apr 16 --- espresso/models/external_language_model.py | 78 ++- espresso/models/lstm_lm.py | 146 +++--- espresso/models/speech_lstm.py | 458 ++++++++++-------- .../tensorized_lookahead_language_model.py | 28 +- examples/asr_librispeech/cmd.sh | 4 +- examples/asr_swbd/cmd.sh | 4 +- examples/asr_wsj/cmd.sh | 4 +- 7 files changed, 379 insertions(+), 343 deletions(-) diff --git a/espresso/models/external_language_model.py b/espresso/models/external_language_model.py index 505a2ca48..dd418ff97 100644 --- a/espresso/models/external_language_model.py +++ b/espresso/models/external_language_model.py @@ -6,7 +6,6 @@ import math import torch -from fairseq import utils from fairseq.models import FairseqIncrementalDecoder, FairseqLanguageModel from espresso.data import AsrDictionary @@ -100,9 +99,7 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): batch_space_mask = prev_output_tokens.squeeze(-1).eq(self.subword_space_idx) - cached_state = utils.get_incremental_state( - self.lm_decoder, incremental_state, 'cached_state', - ) + cached_state = self.lm_decoder.get_incremental_state(incremental_state, 'cached_state') if cached_state is None: # it is the first time step assert (prev_output_tokens == self.subword_eos_idx).all(), \ @@ -115,14 +112,14 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): cumsum_probs = torch.cumsum(lm_probs, dim=-1) # B x 1 x V nodes = [self.lexroot] * bsz else: - cumsum_probs = utils.get_incremental_state(self, incremental_state, 'cumsum_probs') - nodes = utils.get_incremental_state(self, incremental_state, 'nodes') + cumsum_probs = self.get_incremental_state(incremental_state, 'cumsum_probs') + nodes = self.get_incremental_state(incremental_state, 'nodes') assert len(nodes) == bsz w = prev_output_tokens.new([ node.word_idx if node is not None and node.word_idx >= 0 
else self.word_unk_idx for node in nodes ]).unsqueeze(-1) # B x 1 - old_cached_state = _clone_cached_state(cached_state) + old_cached_state = _clone_cached_state(self.lm_decoder.get_cached_state(incremental_state)) # recompute cumsum_probs from inter-word transition probabilities # only for those whose prev_output_token is lm_probs = self.lm_decoder.get_normalized_probs( @@ -145,8 +142,8 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): else: # no path in the tree nodes[i] = None - utils.set_incremental_state(self, incremental_state, 'cumsum_probs', cumsum_probs) - utils.set_incremental_state(self, incremental_state, 'nodes', nodes) + self.set_incremental_state(incremental_state, 'cumsum_probs', cumsum_probs) + self.set_incremental_state(incremental_state, 'nodes', nodes) # initialize out_probs (B x 1 x V) if self.open_vocab: @@ -257,21 +254,16 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): def reorder_incremental_state(self, incremental_state, new_order): super().reorder_incremental_state(incremental_state, new_order) - cumsum_probs = utils.get_incremental_state( - self, incremental_state, 'cumsum_probs') + cumsum_probs = self.get_incremental_state(incremental_state, 'cumsum_probs') if cumsum_probs is not None: new_cumsum_probs = cumsum_probs.index_select(0, new_order) - utils.set_incremental_state( - self, incremental_state, 'cumsum_probs', new_cumsum_probs, - ) + self.set_incremental_state(incremental_state, 'cumsum_probs', new_cumsum_probs) - nodes = utils.get_incremental_state(self, incremental_state, 'nodes') + nodes = self.get_incremental_state(incremental_state, 'nodes') if nodes is not None: new_order_list = new_order.tolist() new_nodes = [nodes[i] for i in new_order_list] - utils.set_incremental_state( - self, incremental_state, 'nodes', new_nodes, - ) + self.set_incremental_state(incremental_state, 'nodes', new_nodes) def get_normalized_probs(self, net_output, log_probs, sample=None): """Get normalized probabilities (or log probs) from a net's output.""" @@ -344,10 +336,12 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): batch_space_mask = prev_output_tokens.squeeze(-1).eq(self.subword_space_idx) batch_not_space_mask = ~batch_space_mask - wordlm_cached_state = utils.get_incremental_state( - self.wordlm_decoder, incremental_state, 'cached_state') - subwordlm_cached_state = utils.get_incremental_state( - self.subwordlm_decoder, incremental_state, 'cached_state') + wordlm_cached_state = self.wordlm_decoder.get_incremental_state( + incremental_state, 'cached_state', + ) + subwordlm_cached_state = self.subwordlm_decoder.get_incremental_state( + incremental_state, 'cached_state', + ) if wordlm_cached_state is None: # it is the first time step assert subwordlm_cached_state is None @@ -368,16 +362,10 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): subword_cumlogprobs = out_logprobs.new_zeros(sw.size()) nodes = [self.lexroot] * bsz else: - wordlm_logprobs = utils.get_incremental_state( - self, incremental_state, 'wordlm_logprobs', - ) - out_logprobs = utils.get_incremental_state( - self, incremental_state, 'out_logprobs', - ) - subword_cumlogprobs = utils.get_incremental_state( - self, incremental_state, 'subword_cumlogprobs', - ) - nodes = utils.get_incremental_state(self, incremental_state, 'nodes') + wordlm_logprobs = self.get_incremental_state(incremental_state, 'wordlm_logprobs') + out_logprobs = self.get_incremental_state(incremental_state, 
'out_logprobs') + subword_cumlogprobs = self.get_incremental_state(incremental_state, 'subword_cumlogprobs') + nodes = self.get_incremental_state(incremental_state, 'nodes') assert len(nodes) == bsz w = prev_output_tokens.new([ node.word_idx if node is not None and node.word_idx >= 0 else @@ -435,13 +423,9 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): batch_oov_mask = batch_not_space_mask & ~batch_is_child_mask out_logprobs[batch_oov_mask] = self.logzero - utils.set_incremental_state( - self, incremental_state, 'wordlm_logprobs', wordlm_logprobs, - ) - utils.set_incremental_state( - self, incremental_state, 'subword_cumlogprobs', subword_cumlogprobs, - ) - utils.set_incremental_state(self, incremental_state, 'nodes', nodes) + self.set_incremental_state(incremental_state, 'wordlm_logprobs', wordlm_logprobs) + self.set_incremental_state(incremental_state, 'subword_cumlogprobs', subword_cumlogprobs) + self.set_incremental_state(incremental_state, 'nodes', nodes) # apply word-level probabilies for emitting w = prev_output_tokens.new([ @@ -468,9 +452,7 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): out_logprobs[batch_space_mask, :, self.subword_eos_idx] += \ wordlm_logprobs[batch_space_mask, :, self.word_eos_idx] - utils.set_incremental_state( - self, incremental_state, 'out_logprobs', out_logprobs, - ) + self.set_incremental_state(incremental_state, 'out_logprobs', out_logprobs) # note that here we return log-probs rather than logits, and the second # element is None, which is usually a tensor of attention weights in @@ -481,20 +463,16 @@ def reorder_incremental_state(self, incremental_state, new_order): super().reorder_incremental_state(incremental_state, new_order) for state_name in ['wordlm_logprobs', 'out_logprobs', 'subword_cumlogprobs']: - state = utils.get_incremental_state(self, incremental_state, state_name) + state = self.get_incremental_state(incremental_state, state_name) if state is not None: new_state = state.index_select(0, new_order) - utils.set_incremental_state( - self, incremental_state, state_name, new_state, - ) + self.set_incremental_state(incremental_state, state_name, new_state) - nodes = utils.get_incremental_state(self, incremental_state, 'nodes') + nodes = self.get_incremental_state(incremental_state, 'nodes') if nodes is not None: new_order_list = new_order.tolist() new_nodes = [nodes[i] for i in new_order_list] - utils.set_incremental_state( - self, incremental_state, 'nodes', new_nodes, - ) + self.set_incremental_state(incremental_state, 'nodes', new_nodes) def get_normalized_probs(self, net_output, log_probs, sample=None): """Get normalized probabilities (or log probs) from a net's output.""" diff --git a/espresso/models/lstm_lm.py b/espresso/models/lstm_lm.py index 151046d55..5639a2196 100644 --- a/espresso/models/lstm_lm.py +++ b/espresso/models/lstm_lm.py @@ -18,7 +18,7 @@ DEFAULT_MAX_TARGET_POSITIONS = 1e5 -@register_model('lstm_lm_espresso') +@register_model("lstm_lm_espresso") class LSTMLanguageModelEspresso(FairseqLanguageModel): def __init__(self, decoder, args): super().__init__(decoder) @@ -28,37 +28,37 @@ def __init__(self, decoder, args): def add_args(parser): """Add model-specific arguments to the parser.""" # fmt: off - parser.add_argument('--dropout', type=float, metavar='D', - help='dropout probability') - parser.add_argument('--decoder-embed-dim', type=int, metavar='N', - help='decoder embedding dimension') - parser.add_argument('--decoder-embed-path', type=str, 
metavar='STR', - help='path to pre-trained decoder embedding') - parser.add_argument('--decoder-freeze-embed', action='store_true', - help='freeze decoder embeddings') - parser.add_argument('--decoder-hidden-size', type=int, metavar='N', - help='decoder hidden size') - parser.add_argument('--decoder-layers', type=int, metavar='N', - help='number of decoder layers') - parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N', - help='decoder output embedding dimension') - parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', - help='comma separated list of adaptive softmax cutoff points. ' - 'Must be used with adaptive_loss criterion') - parser.add_argument('--share-embed', + parser.add_argument("--dropout", type=float, metavar="D", + help="dropout probability") + parser.add_argument("--decoder-embed-dim", type=int, metavar="N", + help="decoder embedding dimension") + parser.add_argument("--decoder-embed-path", type=str, metavar="STR", + help="path to pre-trained decoder embedding") + parser.add_argument("--decoder-freeze-embed", action="store_true", + help="freeze decoder embeddings") + parser.add_argument("--decoder-hidden-size", type=int, metavar="N", + help="decoder hidden size") + parser.add_argument("--decoder-layers", type=int, metavar="N", + help="number of decoder layers") + parser.add_argument("--decoder-out-embed-dim", type=int, metavar="N", + help="decoder output embedding dimension") + parser.add_argument("--adaptive-softmax-cutoff", metavar="EXPR", + help="comma separated list of adaptive softmax cutoff points. " + "Must be used with adaptive_loss criterion") + parser.add_argument("--share-embed", type=lambda x: options.eval_bool(x), - help='share input and output embeddings') - parser.add_argument('--is-wordlm', action='store_true', - help='whether it is word LM or subword LM. Only ' - 'relevant for ASR decoding with LM, and it determines ' - 'how the underlying decoder instance gets the dictionary' - 'from the task instance when calling cls.build_model()') + help="share input and output embeddings") + parser.add_argument("--is-wordlm", action="store_true", + help="whether it is word LM or subword LM. 
Only " + "relevant for ASR decoding with LM, and it determines " + "how the underlying decoder instance gets the dictionary " + "from the task instance when calling cls.build_model()") # Granular dropout settings (if not specified these default to --dropout) - parser.add_argument('--decoder-dropout-in', type=float, metavar='D', - help='dropout probability for decoder input embedding') - parser.add_argument('--decoder-dropout-out', type=float, metavar='D', - help='dropout probability for decoder output') + parser.add_argument("--decoder-dropout-in", type=float, metavar="D", + help="dropout probability for decoder input embedding") + parser.add_argument("--decoder-dropout-out", type=float, metavar="D", + help="dropout probability for decoder output") # fmt: on @classmethod @@ -67,10 +67,10 @@ def build_model(cls, args, task): # make sure all arguments are present in older models base_lm_architecture(args) - if getattr(args, 'max_target_positions', None) is not None: + if getattr(args, "max_target_positions", None) is not None: max_target_positions = args.max_target_positions else: - max_target_positions = getattr(args, 'tokens_per_sample', DEFAULT_MAX_TARGET_POSITIONS) + max_target_positions = getattr(args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS) def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): num_embeddings = len(dictionary) @@ -80,7 +80,7 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): utils.print_embed_overlap(embed_dict, dictionary) return utils.load_embedding(embed_dict, dictionary, embed_tokens) - if args.is_wordlm and hasattr(task, 'word_dictionary'): + if args.is_wordlm and hasattr(task, "word_dictionary"): dictionary = task.word_dictionary elif isinstance(task, SpeechRecognitionEspressoTask): dictionary = task.target_dictionary @@ -99,8 +99,8 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): if args.share_embed and ( args.decoder_embed_dim != args.decoder_out_embed_dim): raise ValueError( - '--share-embed requires ' - '--decoder-embed-dim to match --decoder-out-embed-dim' + "--share-embed requires " + "--decoder-embed-dim to match --decoder-out-embed-dim" ) if args.decoder_freeze_embed: @@ -120,64 +120,64 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): share_input_output_embed=args.share_embed, adaptive_softmax_cutoff=( options.eval_str_list(args.adaptive_softmax_cutoff, type=int) - if args.criterion == 'adaptive_loss' else None + if args.criterion == "adaptive_loss" else None ), max_target_positions=max_target_positions, ) return cls(decoder, args) -@register_model_architecture('lstm_lm_espresso', 'lstm_lm_espresso') +@register_model_architecture("lstm_lm_espresso", "lstm_lm_espresso") def base_lm_architecture(args): - args.dropout = getattr(args, 'dropout', 0.1) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 48) - args.decoder_embed_path = getattr(args, 'decoder_embed_path', None) - args.decoder_freeze_embed = getattr(args, 'decoder_freeze_embed', False) - args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 650) - args.decoder_layers = getattr(args, 'decoder_layers', 2) - args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 650) - args.decoder_rnn_residual = getattr(args, 'decoder_rnn_residual', False) - args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None) - args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', args.dropout) - args.decoder_dropout_out = getattr(args, 
'decoder_dropout_out', args.dropout) - args.share_embed = getattr(args, 'share_embed', False) - args.is_wordlm = getattr(args, 'is_wordlm', False) - - -@register_model_architecture('lstm_lm_espresso', 'lstm_lm_wsj') + args.dropout = getattr(args, "dropout", 0.1) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 48) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_freeze_embed = getattr(args, "decoder_freeze_embed", False) + args.decoder_hidden_size = getattr(args, "decoder_hidden_size", 650) + args.decoder_layers = getattr(args, "decoder_layers", 2) + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 650) + args.decoder_rnn_residual = getattr(args, "decoder_rnn_residual", False) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.decoder_dropout_in = getattr(args, "decoder_dropout_in", args.dropout) + args.decoder_dropout_out = getattr(args, "decoder_dropout_out", args.dropout) + args.share_embed = getattr(args, "share_embed", False) + args.is_wordlm = getattr(args, "is_wordlm", False) + + +@register_model_architecture("lstm_lm_espresso", "lstm_lm_wsj") def lstm_lm_wsj(args): base_lm_architecture(args) -@register_model_architecture('lstm_lm_espresso', 'lstm_lm_librispeech') +@register_model_architecture("lstm_lm_espresso", "lstm_lm_librispeech") def lstm_lm_librispeech(args): - args.dropout = getattr(args, 'dropout', 0.0) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 800) - args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 800) - args.decoder_layers = getattr(args, 'decoder_layers', 4) - args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 800) - args.share_embed = getattr(args, 'share_embed', True) + args.dropout = getattr(args, "dropout", 0.0) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 800) + args.decoder_hidden_size = getattr(args, "decoder_hidden_size", 800) + args.decoder_layers = getattr(args, "decoder_layers", 4) + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 800) + args.share_embed = getattr(args, "share_embed", True) base_lm_architecture(args) -@register_model_architecture('lstm_lm_espresso', 'lstm_lm_swbd') +@register_model_architecture("lstm_lm_espresso", "lstm_lm_swbd") def lstm_lm_swbd(args): - args.dropout = getattr(args, 'dropout', 0.3) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1800) - args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 1800) - args.decoder_layers = getattr(args, 'decoder_layers', 3) - args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 1800) - args.share_embed = getattr(args, 'share_embed', True) + args.dropout = getattr(args, "dropout", 0.3) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1800) + args.decoder_hidden_size = getattr(args, "decoder_hidden_size", 1800) + args.decoder_layers = getattr(args, "decoder_layers", 3) + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 1800) + args.share_embed = getattr(args, "share_embed", True) base_lm_architecture(args) -@register_model_architecture('lstm_lm_espresso', 'lstm_wordlm_wsj') +@register_model_architecture("lstm_lm_espresso", "lstm_wordlm_wsj") def lstm_wordlm_wsj(args): - args.dropout = getattr(args, 'dropout', 0.35) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1200) - args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 1200) - args.decoder_layers = getattr(args, 'decoder_layers', 3) - args.decoder_out_embed_dim 
= getattr(args, 'decoder_out_embed_dim', 1200) - args.share_embed = getattr(args, 'share_embed', True) + args.dropout = getattr(args, "dropout", 0.35) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1200) + args.decoder_hidden_size = getattr(args, "decoder_hidden_size", 1200) + args.decoder_layers = getattr(args, "decoder_layers", 3) + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 1200) + args.share_embed = getattr(args, "share_embed", True) args.is_wordlm = True base_lm_architecture(args) diff --git a/espresso/models/speech_lstm.py b/espresso/models/speech_lstm.py index bb1b24c4e..e19364be5 100644 --- a/espresso/models/speech_lstm.py +++ b/espresso/models/speech_lstm.py @@ -4,8 +4,10 @@ # LICENSE file in the root directory of this source tree. import logging +from typing import Dict, List, Optional, Tuple import torch +from torch import Tensor import torch.nn as nn import torch.nn.functional as F @@ -38,7 +40,7 @@ logger = logging.getLogger(__name__) -@register_model('speech_lstm') +@register_model("speech_lstm") class SpeechLSTMModel(FairseqEncoderDecoderModel): def __init__(self, encoder, decoder, pretrained_lm=None): super().__init__(encoder, decoder) @@ -51,79 +53,79 @@ def __init__(self, encoder, decoder, pretrained_lm=None): def add_args(parser): """Add model-specific arguments to the parser.""" # fmt: off - parser.add_argument('--dropout', type=float, metavar='D', - help='dropout probability') - parser.add_argument('--encoder-conv-channels', type=str, metavar='EXPR', - help='list of encoder convolution\'s out channels') - parser.add_argument('--encoder-conv-kernel-sizes', type=str, metavar='EXPR', - help='list of encoder convolution\'s kernel sizes') - parser.add_argument('--encoder-conv-strides', type=str, metavar='EXPR', - help='list of encoder convolution\'s strides') - parser.add_argument('--encoder-rnn-hidden-size', type=int, metavar='N', - help='encoder rnn\'s hidden size') - parser.add_argument('--encoder-rnn-layers', type=int, metavar='N', - help='number of rnn encoder layers') - parser.add_argument('--encoder-rnn-bidirectional', + parser.add_argument("--dropout", type=float, metavar="D", + help="dropout probability") + parser.add_argument("--encoder-conv-channels", type=str, metavar="EXPR", + help="list of encoder convolution\'s out channels") + parser.add_argument("--encoder-conv-kernel-sizes", type=str, metavar="EXPR", + help="list of encoder convolution\'s kernel sizes") + parser.add_argument("--encoder-conv-strides", type=str, metavar="EXPR", + help="list of encoder convolution\'s strides") + parser.add_argument("--encoder-rnn-hidden-size", type=int, metavar="N", + help="encoder rnn\'s hidden size") + parser.add_argument("--encoder-rnn-layers", type=int, metavar="N", + help="number of rnn encoder layers") + parser.add_argument("--encoder-rnn-bidirectional", type=lambda x: options.eval_bool(x), - help='make all rnn layers of encoder bidirectional') - parser.add_argument('--encoder-rnn-residual', + help="make all rnn layers of encoder bidirectional") + parser.add_argument("--encoder-rnn-residual", type=lambda x: options.eval_bool(x), - help='create residual connections for rnn encoder ' - 'layers (starting from the 2nd layer), i.e., the actual ' - 'output of such layer is the sum of its input and output') - parser.add_argument('--decoder-embed-dim', type=int, metavar='N', - help='decoder embedding dimension') - parser.add_argument('--decoder-embed-path', type=str, metavar='STR', - help='path to pre-trained decoder embedding') - 
parser.add_argument('--decoder-freeze-embed', action='store_true', - help='freeze decoder embeddings') - parser.add_argument('--decoder-hidden-size', type=int, metavar='N', - help='decoder hidden size') - parser.add_argument('--decoder-layers', type=int, metavar='N', - help='number of decoder layers') - parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N', - help='decoder output embedding dimension') - parser.add_argument('--decoder-rnn-residual', + help="create residual connections for rnn encoder " + "layers (starting from the 2nd layer), i.e., the actual " + "output of such layer is the sum of its input and output") + parser.add_argument("--decoder-embed-dim", type=int, metavar="N", + help="decoder embedding dimension") + parser.add_argument("--decoder-embed-path", type=str, metavar="STR", + help="path to pre-trained decoder embedding") + parser.add_argument("--decoder-freeze-embed", action="store_true", + help="freeze decoder embeddings") + parser.add_argument("--decoder-hidden-size", type=int, metavar="N", + help="decoder hidden size") + parser.add_argument("--decoder-layers", type=int, metavar="N", + help="number of decoder layers") + parser.add_argument("--decoder-out-embed-dim", type=int, metavar="N", + help="decoder output embedding dimension") + parser.add_argument("--decoder-rnn-residual", type=lambda x: options.eval_bool(x), - help='create residual connections for rnn decoder ' - 'layers (starting from the 2nd layer), i.e., the actual ' - 'output of such layer is the sum of its input and output') - parser.add_argument('--attention-type', type=str, metavar='STR', - choices=['bahdanau', 'luong'], - help='attention type') - parser.add_argument('--attention-dim', type=int, metavar='N', - help='attention dimension') - parser.add_argument('--need-attention', action='store_true', - help='need to return attention tensor for the caller') - parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', - help='comma separated list of adaptive softmax cutoff points. ' - 'Must be used with adaptive_loss criterion') - parser.add_argument('--share-decoder-input-output-embed', + help="create residual connections for rnn decoder " + "layers (starting from the 2nd layer), i.e., the actual " + "output of such layer is the sum of its input and output") + parser.add_argument("--attention-type", type=str, metavar="STR", + choices=["bahdanau", "luong"], + help="attention type") + parser.add_argument("--attention-dim", type=int, metavar="N", + help="attention dimension") + parser.add_argument("--need-attention", action="store_true", + help="need to return attention tensor for the caller") + parser.add_argument("--adaptive-softmax-cutoff", metavar="EXPR", + help="comma separated list of adaptive softmax cutoff points. 
" + "Must be used with adaptive_loss criterion") + parser.add_argument("--share-decoder-input-output-embed", type=lambda x: options.eval_bool(x), - help='share decoder input and output embeddings') - parser.add_argument('--pretrained-lm-checkpoint', type=str, metavar='STR', - help='path to load checkpoint from pretrained language model(LM), ' - 'which will be present and kept fixed during training.') + help="share decoder input and output embeddings") + parser.add_argument("--pretrained-lm-checkpoint", type=str, metavar="STR", + help="path to load checkpoint from pretrained language model(LM), " + "which will be present and kept fixed during training.") # Granular dropout settings (if not specified these default to --dropout) - parser.add_argument('--encoder-rnn-dropout-in', type=float, metavar='D', - help='dropout probability for encoder rnn\'s input') - parser.add_argument('--encoder-rnn-dropout-out', type=float, metavar='D', - help='dropout probability for encoder rnn\'s output') - parser.add_argument('--decoder-dropout-in', type=float, metavar='D', - help='dropout probability for decoder input embedding') - parser.add_argument('--decoder-dropout-out', type=float, metavar='D', - help='dropout probability for decoder output') + parser.add_argument("--encoder-rnn-dropout-in", type=float, metavar="D", + help="dropout probability for encoder rnn\'s input") + parser.add_argument("--encoder-rnn-dropout-out", type=float, metavar="D", + help="dropout probability for encoder rnn\'s output") + parser.add_argument("--decoder-dropout-in", type=float, metavar="D", + help="dropout probability for decoder input embedding") + parser.add_argument("--decoder-dropout-out", type=float, metavar="D", + help="dropout probability for decoder output") # Scheduled sampling options - parser.add_argument('--scheduled-sampling-probs', type=lambda p: options.eval_str_list(p), - metavar='P_1,P_2,...,P_N', default=[1.0], - help='scheduled sampling probabilities of sampling the truth ' - 'labels for N epochs starting from --start-schedule-sampling-epoch; ' - 'all later epochs using P_N') - parser.add_argument('--start-scheduled-sampling-epoch', type=int, - metavar='N', default=1, - help='start scheduled sampling from the specified epoch') + parser.add_argument("--scheduled-sampling-probs", type=lambda p: options.eval_str_list(p), + metavar="P_1,P_2,...,P_N", default=[1.0], + help="scheduled sampling probabilities of sampling the truth " + "labels for N epochs starting from --start-schedule-sampling-epoch; " + "all later epochs using P_N") + parser.add_argument("--start-scheduled-sampling-epoch", type=int, + metavar="N", default=1, + help="start scheduled sampling from the specified epoch") # fmt: on @classmethod @@ -132,8 +134,8 @@ def build_model(cls, args, task): # make sure that all args are properly defaulted (in case there are any new ones) base_architecture(args) - max_source_positions = getattr(args, 'max_source_positions', DEFAULT_MAX_SOURCE_POSITIONS) - max_target_positions = getattr(args, 'max_target_positions', DEFAULT_MAX_TARGET_POSITIONS) + max_source_positions = getattr(args, "max_source_positions", DEFAULT_MAX_SOURCE_POSITIONS) + max_target_positions = getattr(args, "max_target_positions", DEFAULT_MAX_TARGET_POSITIONS) def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): num_embeddings = len(dictionary) @@ -155,8 +157,8 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): if args.share_decoder_input_output_embed and ( args.decoder_embed_dim != 
args.decoder_out_embed_dim): raise ValueError( - '--share-decoder-input-output-embed requires ' - '--decoder-embed-dim to match --decoder-out-embed-dim' + "--share-decoder-input-output-embed requires " + "--decoder-embed-dim to match --decoder-out-embed-dim" ) if args.decoder_freeze_embed: @@ -165,7 +167,7 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): out_channels = speech_utils.eval_str_nested_list_or_tuple(args.encoder_conv_channels, type=int) kernel_sizes = speech_utils.eval_str_nested_list_or_tuple(args.encoder_conv_kernel_sizes, type=int) strides = speech_utils.eval_str_nested_list_or_tuple(args.encoder_conv_strides, type=int) - logger.info('input feature dimension: {}, channels: {}'.format(task.feat_dim, task.feat_in_channels)) + logger.info("input feature dimension: {}, channels: {}".format(task.feat_dim, task.feat_in_channels)) assert task.feat_dim % task.feat_in_channels == 0 conv_layers = ConvBNReLU( out_channels, kernel_sizes, strides, in_channels=task.feat_in_channels, @@ -217,14 +219,14 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): share_input_output_embed=args.share_decoder_input_output_embed, adaptive_softmax_cutoff=( options.eval_str_list(args.adaptive_softmax_cutoff, type=int) - if args.criterion == 'adaptive_loss' else None + if args.criterion == "adaptive_loss" else None ), max_target_positions=max_target_positions, scheduled_sampling_rate_scheduler=scheduled_sampling_rate_scheduler, ) pretrained_lm = None if args.pretrained_lm_checkpoint: - logger.info('loading pretrained LM from {}'.format(args.pretrained_lm_checkpoint)) + logger.info("loading pretrained LM from {}".format(args.pretrained_lm_checkpoint)) pretrained_lm = checkpoint_utils.load_model_ensemble( args.pretrained_lm_checkpoint, task=task)[0][0] pretrained_lm.make_generation_fast_() @@ -233,6 +235,21 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): param.requires_grad = False return cls(encoder, decoder, pretrained_lm) + def forward( + self, + src_tokens, + src_lengths, + prev_output_tokens, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + epoch=1, + ): + encoder_out = self.encoder(src_tokens, src_lengths=src_lengths) + decoder_out = self.decoder( + prev_output_tokens, encoder_out=encoder_out, + incremental_state=incremental_state, epoch=epoch, + ) + return decoder_out + def set_num_updates(self, num_updates): self.num_updates = num_updates super().set_num_updates(num_updates) @@ -342,7 +359,7 @@ def output_lengths(self, in_lengths): return in_lengths if self.conv_layers_before is None \ else self.conv_layers_before.output_lengths(in_lengths) - def forward(self, src_tokens, src_lengths, **unused): + def forward(self, src_tokens, src_lengths: Tensor, **unused): if self.left_pad: # nn.utils.rnn.pack_padded_sequence requires right-padding; # convert left-padding to right-padding @@ -372,13 +389,13 @@ def forward(self, src_tokens, src_lengths, **unused): if self.residual and i > 0: # residual connection starts from the 2nd layer prev_x = x # pack embedded source tokens into a PackedSequence - packed_x = nn.utils.rnn.pack_padded_sequence(x, src_lengths.data.tolist()) + packed_x = nn.utils.rnn.pack_padded_sequence(x, src_lengths.data) # apply LSTM packed_outs, (_, _) = self.lstm[i](packed_x, (h0, c0)) # unpack outputs and apply dropout - x, _ = nn.utils.rnn.pad_packed_sequence(packed_outs, padding_value=self.padding_value) + x, _ = nn.utils.rnn.pad_packed_sequence(packed_outs, 
padding_value=self.padding_value*1.0) if i < len(self.lstm) - 1: # not applying dropout for the last layer x = F.dropout(x, p=self.dropout_out, training=self.training) x = x + prev_x if self.residual and i > 0 else x @@ -387,18 +404,20 @@ def forward(self, src_tokens, src_lengths, **unused): encoder_padding_mask = padding_mask.t() return { - 'encoder_out': (x,), - 'encoder_padding_mask': encoder_padding_mask if encoder_padding_mask.any() else None + "encoder_out": (x, src_lengths), + "encoder_padding_mask": (encoder_padding_mask if encoder_padding_mask.any() else None, torch.empty(0)), } def reorder_encoder_out(self, encoder_out, new_order): - encoder_out['encoder_out'] = tuple( - eo.index_select(1, new_order) - for eo in encoder_out['encoder_out'] + encoder_out["encoder_out"] = ( + encoder_out["encoder_out"][0].index_select(1, new_order), + encoder_out["encoder_out"][1].index_select(0, new_order), ) - if encoder_out['encoder_padding_mask'] is not None: - encoder_out['encoder_padding_mask'] = \ - encoder_out['encoder_padding_mask'].index_select(1, new_order) + if encoder_out["encoder_padding_mask"][0] is not None: + encoder_out["encoder_padding_mask"] = ( + encoder_out["encoder_padding_mask"][0].index_select(1, new_order), + encoder_out['encoder_padding_mask'][1], + ) return encoder_out def max_positions(self): @@ -421,13 +440,14 @@ def __init__( self.dropout_out = dropout_out self.hidden_size = hidden_size self.share_input_output_embed = share_input_output_embed - if attn_type is None or attn_type.lower() == 'none': + if attn_type is None or attn_type.lower() == "none": # no attention, no encoder output needed (language model case) need_attn = False encoder_output_units = 0 self.need_attn = need_attn self.residual = residual self.max_target_positions = max_target_positions + self.num_layers = num_layers self.adaptive_softmax = None num_embeddings = len(dictionary) @@ -446,18 +466,18 @@ def __init__( ) for layer in range(num_layers) ]) - if attn_type is None or attn_type.lower() == 'none': + if attn_type is None or attn_type.lower() == "none": self.attention = None - elif attn_type.lower() == 'bahdanau': + elif attn_type.lower() == "bahdanau": self.attention = speech_attention.BahdanauAttention( hidden_size, encoder_output_units, attn_dim, ) - elif attn_type.lower() == 'luong': + elif attn_type.lower() == "luong": self.attention = speech_attention.LuongAttention( hidden_size, encoder_output_units, ) else: - raise ValueError('unrecognized attention type.') + raise ValueError("unrecognized attention type.") if hidden_size + encoder_output_units != out_embed_dim: self.additional_fc = Linear(hidden_size + encoder_output_units, out_embed_dim) if adaptive_softmax_cutoff is not None: @@ -469,7 +489,25 @@ def __init__( self.scheduled_sampling_rate_scheduler = scheduled_sampling_rate_scheduler - def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs): + def get_cached_state(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]): + cached_state = self.get_incremental_state(incremental_state, "cached_state") + assert cached_state is not None + prev_hiddens_ = cached_state["prev_hiddens"] + assert prev_hiddens_ is not None + prev_cells_ = cached_state["prev_cells"] + assert prev_cells_ is not None + prev_hiddens = [prev_hiddens_[i] for i in range(self.num_layers)] + prev_cells = [prev_cells_[j] for j in range(self.num_layers)] + input_feed = cached_state["input_feed"] # can be None for decoder-only language models + return prev_hiddens, prev_cells, 
input_feed + + def forward( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, Tuple[Tensor, Tensor]]] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + **kwargs, + ): """ Args: prev_output_tokens (LongTensor): previous decoder outputs of shape @@ -485,7 +523,7 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, - attention weights of shape `(batch, tgt_len, src_len)` """ if self.scheduled_sampling_rate_scheduler is not None: - epoch = kwargs.get('epoch', 1) + epoch = kwargs.get("epoch", 1) sampling_prob = self.scheduled_sampling_rate_scheduler.step(epoch) if sampling_prob < 1.0: # apply scheduled sampling return self._forward_with_scheduled_sampling( @@ -499,7 +537,11 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, return self.output_layer(x), attn_scores def _forward_with_scheduled_sampling( - self, prev_output_tokens, sampling_prob, encoder_out=None, incremental_state=None, + self, + prev_output_tokens, + sampling_prob, + encoder_out: Optional[Dict[str, Tuple[Tensor, Tensor]]] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, ): bsz, seqlen = prev_output_tokens.size() outs = [] @@ -523,7 +565,11 @@ def _forward_with_scheduled_sampling( return x, None def extract_features( - self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused, + self, + prev_output_tokens, + encoder_out: Dict[str, Tuple[Tensor, Tensor]], + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + **unused, ): """ Similar to *forward* but only return features. @@ -535,18 +581,19 @@ def extract_features( """ if encoder_out is not None: assert self.attention is not None - encoder_padding_mask = encoder_out['encoder_padding_mask'] - encoder_out = encoder_out['encoder_out'] + encoder_padding_mask, *_ = encoder_out["encoder_padding_mask"] + encoder_out = encoder_out["encoder_out"] # get outputs from encoder encoder_outs = encoder_out[0] srclen = encoder_outs.size(0) else: encoder_padding_mask = None encoder_out = None - srclen = None + srclen = 0 - if incremental_state is not None: + if incremental_state is not None and len(incremental_state) > 0: prev_output_tokens = prev_output_tokens[:, -1:] + bsz, seqlen = prev_output_tokens.size() # embed tokens @@ -557,14 +604,13 @@ def extract_features( x = x.transpose(0, 1) # initialize previous states (or get from cache during incremental generation) - cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state') - if cached_state is not None: - prev_hiddens, prev_cells, input_feed = cached_state + + if incremental_state is not None and len(incremental_state) > 0: + prev_hiddens, prev_cells, input_feed = self.get_cached_state(incremental_state) else: - num_layers = len(self.layers) zero_state = x.new_zeros(bsz, self.hidden_size) - prev_hiddens = [zero_state for i in range(num_layers)] - prev_cells = [zero_state for i in range(num_layers)] + prev_hiddens = [zero_state for i in range(self.num_layers)] + prev_cells = [zero_state for i in range(self.num_layers)] input_feed = x.new_zeros(bsz, self.encoder_output_units) \ if encoder_out is not None else None @@ -586,6 +632,7 @@ def extract_features( # compute and apply attention using the 1st layer's hidden state if encoder_out is not None: if i == 0: + assert attn_scores is not None context, attn_scores[:, j, :], _ = self.attention( hidden, encoder_outs, encoder_padding_mask, ) @@ -614,11 +661,14 @@ def extract_features( # 
save final output outs.append(input) - # cache previous states (no-op except during incremental generation) - utils.set_incremental_state( - self, incremental_state, 'cached_state', - (prev_hiddens, prev_cells, input_feed), + # Stack all the necessary tensors together and store + prev_hiddens_tensor = torch.stack(prev_hiddens) + prev_cells_tensor = torch.stack(prev_cells) + cache_state = torch.jit.annotate( + Dict[str, Optional[Tensor]], + {"prev_hiddens": prev_hiddens_tensor, "prev_cells": prev_cells_tensor, "input_feed": input_feed} ) + self.set_incremental_state(incremental_state, "cached_state", cache_state) # collect outputs across time steps x = torch.cat(outs, dim=0).view(seqlen, bsz, -1) @@ -627,12 +677,12 @@ def extract_features( # T x B x C -> B x T x C x = x.transpose(1, 0) - if hasattr(self, 'additional_fc') and self.adaptive_softmax is None: + if hasattr(self, "additional_fc") and self.adaptive_softmax is None: x = self.additional_fc(x) x = F.dropout(x, p=self.dropout_out, training=self.training) - # srclen x tgtlen x bsz -> bsz x tgtlen x srclen if not self.training and encoder_out is not None and self.need_attn: + assert attn_scores is not None attn_scores = attn_scores.transpose(0, 2) else: attn_scores = None @@ -650,48 +700,62 @@ def output_layer(self, features, **kwargs): else: return features - def reorder_incremental_state(self, incremental_state, new_order): + def reorder_state(self, state: List[Tensor], new_order): + return [ + state_i.index_select(0, new_order) if state_i is not None else None + for state_i in state + ] + + def reorder_incremental_state(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], new_order): super().reorder_incremental_state(incremental_state, new_order) - cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state') - if cached_state is None: + if incremental_state is None or len(incremental_state) == 0: return - - def reorder_state(state): - if isinstance(state, list): - return [reorder_state(state_i) for state_i in state] - elif state is not None: - return state.index_select(0, new_order) - else: - return None - - new_state = tuple(map(reorder_state, cached_state)) - utils.set_incremental_state(self, incremental_state, 'cached_state', new_state) + prev_hiddens, prev_cells, input_feed = self.get_cached_state(incremental_state) + cached_state = (prev_hiddens, prev_cells, [input_feed]) + new_state = [self.reorder_state(state, new_order) for state in cached_state] + prev_hiddens_tensor = torch.stack(new_state[0]) + prev_cells_tensor = torch.stack(new_state[1]) + cached_state_new = torch.jit.annotate( + Dict[str, Optional[Tensor]], + {"prev_hiddens": prev_hiddens_tensor, "prev_cells": prev_cells_tensor, "input_feed": new_state[2][0]} + ) + self.set_incremental_state(incremental_state, "cached_state", cached_state_new), + return def masked_copy_incremental_state(self, incremental_state, another_cached_state, mask): - cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state') - if cached_state is None: - assert another_cached_state is None + if incremental_state is None or len(incremental_state) == 0: + assert another_cached_state is None or len(another_cached_state) == 0 return - - def mask_copy_state(state, another_state): - if isinstance(state, list): - assert isinstance(another_state, list) and len(state) == len(another_state) - return [ - mask_copy_state(state_i, another_state_i) - for state_i, another_state_i in zip(state, another_state) - ] - if state is not None: - 
assert state.size(0) == mask.size(0) and another_state is not None and \ - state.size() == another_state.size() - for _ in range(1, len(state.size())): - mask_unsqueezed = mask.unsqueeze(-1) - return torch.where(mask_unsqueezed, state, another_state) - else: - assert another_state is None - return None - - new_state = tuple(map(mask_copy_state, cached_state, another_cached_state)) - utils.set_incremental_state(self, incremental_state, 'cached_state', new_state) + prev_hiddens, prev_cells, input_feed = self.get_cached_state(incremental_state) + cached_state = (prev_hiddens, prev_cells, [input_feed]) + another_cached_state = (another_cached_state[0], another_cached_state[1], [another_cached_state[2]]) + + def mask_copy_state(state: List[Tensor], another_state: List[Tensor]): + new_state = [] + for state_i, another_state_i in zip(state, another_state): + if state_i is None: + assert another_state_i is None + new_state.append(None) + else: + assert state_i.size(0) == mask.size(0) and another_state_i is not None and \ + state_i.size() == another_state_i.size() + mask_unsqueezed = mask + for _ in range(1, len(state_i.size())): + mask_unsqueezed = mask_unsqueezed.unsqueeze(-1) + new_state.append(torch.where(mask_unsqueezed, state_i, another_state_i)) + return new_state + + new_state = [ + mask_copy_state(state, another_state) + for (state, another_state) in zip(cached_state, another_cached_state) + ] + prev_hiddens_tensor = torch.stack(new_state[0]) + prev_cells_tensor = torch.stack(new_state[1]) + cached_state_new = torch.jit.annotate( + Dict[str, Optional[Tensor]], + {"prev_hiddens": prev_hiddens_tensor, "prev_cells": prev_cells_tensor, "input_feed": new_state[2][0]} + ) + self.set_incremental_state(incremental_state, "cached_state", cached_state_new) def max_positions(self): """Maximum output length supported by the decoder.""" @@ -724,71 +788,71 @@ def Convolution2d(in_channels, out_channels, kernel_size, stride): return m -@register_model_architecture('speech_lstm', 'speech_lstm') +@register_model_architecture("speech_lstm", "speech_lstm") def base_architecture(args): - args.dropout = getattr(args, 'dropout', 0.4) + args.dropout = getattr(args, "dropout", 0.4) args.encoder_conv_channels = getattr( - args, 'encoder_conv_channels', '[64, 64, 128, 128]', + args, "encoder_conv_channels", "[64, 64, 128, 128]", ) args.encoder_conv_kernel_sizes = getattr( - args, 'encoder_conv_kernel_sizes', '[(3, 3), (3, 3), (3, 3), (3, 3)]', + args, "encoder_conv_kernel_sizes", "[(3, 3), (3, 3), (3, 3), (3, 3)]", ) args.encoder_conv_strides = getattr( - args, 'encoder_conv_strides', '[(1, 1), (2, 2), (1, 1), (2, 2)]', + args, "encoder_conv_strides", "[(1, 1), (2, 2), (1, 1), (2, 2)]", ) - args.encoder_rnn_hidden_size = getattr(args, 'encoder_rnn_hidden_size', 320) - args.encoder_rnn_layers = getattr(args, 'encoder_rnn_layers', 3) - args.encoder_rnn_bidirectional = getattr(args, 'encoder_rnn_bidirectional', True) - args.encoder_rnn_residual = getattr(args, 'encoder_rnn_residual', False) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 48) - args.decoder_embed_path = getattr(args, 'decoder_embed_path', None) - args.decoder_freeze_embed = getattr(args, 'decoder_freeze_embed', False) - args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 320) - args.decoder_layers = getattr(args, 'decoder_layers', 3) - args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 960) - args.decoder_rnn_residual = getattr(args, 'decoder_rnn_residual', True) - args.attention_type = getattr(args, 
'attention_type', 'bahdanau') - args.attention_dim = getattr(args, 'attention_dim', 320) - args.need_attention = getattr(args, 'need_attention', False) - args.encoder_rnn_dropout_in = getattr(args, 'encoder_rnn_dropout_in', args.dropout) - args.encoder_rnn_dropout_out = getattr(args, 'encoder_rnn_dropout_out', args.dropout) - args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', args.dropout) - args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout) - args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None) - args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False) - args.pretrained_lm_checkpoint = getattr(args, 'pretrained_lm_checkpoint', None) - - -@register_model_architecture('speech_lstm', 'speech_conv_lstm_wsj') + args.encoder_rnn_hidden_size = getattr(args, "encoder_rnn_hidden_size", 320) + args.encoder_rnn_layers = getattr(args, "encoder_rnn_layers", 3) + args.encoder_rnn_bidirectional = getattr(args, "encoder_rnn_bidirectional", True) + args.encoder_rnn_residual = getattr(args, "encoder_rnn_residual", False) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 48) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_freeze_embed = getattr(args, "decoder_freeze_embed", False) + args.decoder_hidden_size = getattr(args, "decoder_hidden_size", 320) + args.decoder_layers = getattr(args, "decoder_layers", 3) + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 960) + args.decoder_rnn_residual = getattr(args, "decoder_rnn_residual", True) + args.attention_type = getattr(args, "attention_type", "bahdanau") + args.attention_dim = getattr(args, "attention_dim", 320) + args.need_attention = getattr(args, "need_attention", False) + args.encoder_rnn_dropout_in = getattr(args, "encoder_rnn_dropout_in", args.dropout) + args.encoder_rnn_dropout_out = getattr(args, "encoder_rnn_dropout_out", args.dropout) + args.decoder_dropout_in = getattr(args, "decoder_dropout_in", args.dropout) + args.decoder_dropout_out = getattr(args, "decoder_dropout_out", args.dropout) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.share_decoder_input_output_embed = getattr(args, "share_decoder_input_output_embed", False) + args.pretrained_lm_checkpoint = getattr(args, "pretrained_lm_checkpoint", None) + + +@register_model_architecture("speech_lstm", "speech_conv_lstm_wsj") def conv_lstm_wsj(args): base_architecture(args) -@register_model_architecture('speech_lstm', 'speech_conv_lstm_librispeech') +@register_model_architecture("speech_lstm", "speech_conv_lstm_librispeech") def speech_conv_lstm_librispeech(args): - args.dropout = getattr(args, 'dropout', 0.3) - args.encoder_rnn_hidden_size = getattr(args, 'encoder_rnn_hidden_size', 1024) - args.encoder_rnn_layers = getattr(args, 'encoder_rnn_layers', 4) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512) - args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 1024) - args.decoder_layers = getattr(args, 'decoder_layers', 3) - args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 3072) - args.decoder_rnn_residual = getattr(args, 'decoder_rnn_residual', True) - args.attention_type = getattr(args, 'attention_type', 'bahdanau') - args.attention_dim = getattr(args, 'attention_dim', 512) + args.dropout = getattr(args, "dropout", 0.3) + args.encoder_rnn_hidden_size = getattr(args, "encoder_rnn_hidden_size", 1024) + args.encoder_rnn_layers = getattr(args, 
"encoder_rnn_layers", 4) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_hidden_size = getattr(args, "decoder_hidden_size", 1024) + args.decoder_layers = getattr(args, "decoder_layers", 3) + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 3072) + args.decoder_rnn_residual = getattr(args, "decoder_rnn_residual", True) + args.attention_type = getattr(args, "attention_type", "bahdanau") + args.attention_dim = getattr(args, "attention_dim", 512) base_architecture(args) -@register_model_architecture('speech_lstm', 'speech_conv_lstm_swbd') +@register_model_architecture("speech_lstm", "speech_conv_lstm_swbd") def speech_conv_lstm_swbd(args): - args.dropout = getattr(args, 'dropout', 0.5) - args.encoder_rnn_hidden_size = getattr(args, 'encoder_rnn_hidden_size', 640) - args.encoder_rnn_layers = getattr(args, 'encoder_rnn_layers', 4) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 640) - args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 640) - args.decoder_layers = getattr(args, 'decoder_layers', 3) - args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 1920) - args.decoder_rnn_residual = getattr(args, 'decoder_rnn_residual', True) - args.attention_type = getattr(args, 'attention_type', 'bahdanau') - args.attention_dim = getattr(args, 'attention_dim', 640) + args.dropout = getattr(args, "dropout", 0.5) + args.encoder_rnn_hidden_size = getattr(args, "encoder_rnn_hidden_size", 640) + args.encoder_rnn_layers = getattr(args, "encoder_rnn_layers", 4) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 640) + args.decoder_hidden_size = getattr(args, "decoder_hidden_size", 640) + args.decoder_layers = getattr(args, "decoder_layers", 3) + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 1920) + args.decoder_rnn_residual = getattr(args, "decoder_rnn_residual", True) + args.attention_type = getattr(args, "attention_type", "bahdanau") + args.attention_dim = getattr(args, "attention_dim", 640) base_architecture(args) diff --git a/espresso/models/tensorized_lookahead_language_model.py b/espresso/models/tensorized_lookahead_language_model.py index 872e38fe8..00cabc3f3 100644 --- a/espresso/models/tensorized_lookahead_language_model.py +++ b/espresso/models/tensorized_lookahead_language_model.py @@ -7,7 +7,6 @@ import torch from fairseq.models import FairseqLanguageModel, FairseqIncrementalDecoder -from fairseq import utils from espresso.data import AsrDictionary from espresso.models.external_language_model import RawOutExternalLanguageModelBase @@ -96,7 +95,7 @@ def forward(self, # Move the batched state to the next state according to the automaton batch_space_mask = prev_output_tokens.squeeze(-1).eq(self.subword_space_idx) # B[Batch] - cached_state = utils.get_incremental_state(self.lm_decoder, incremental_state, 'cached_state') + cached_state = self.lm_decoder.get_incremental_state(incremental_state, 'cached_state') if cached_state is None: # First step assert (prev_output_tokens == self.subword_eos_idx).all(), \ @@ -110,15 +109,15 @@ def forward(self, nodes: torch.Tensor = prev_output_tokens.new_full([bsz], self.tree.root_id) # Z_NodeId[Batch] else: # Not the first step - cumsum_probs: torch.Tensor = utils.get_incremental_state( - self, incremental_state, 'cumsum_probs', + cumsum_probs: torch.Tensor = self.get_incremental_state( + incremental_state, 'cumsum_probs', ) # R[Batch, 1, Vocab] - nodes: torch.Tensor = utils.get_incremental_state(self, incremental_state, 'nodes') # 
Z_NodeId[Batch] + nodes: torch.Tensor = self.get_incremental_state(incremental_state, 'nodes') # Z_NodeId[Batch] assert nodes.size(0) == bsz w: torch.Tensor = self.tree.word_idx[nodes].unsqueeze(1) # Z[Batch, Len=1] w[w < 0] = self.word_unk_idx - old_cached_state = _clone_cached_state(cached_state) + old_cached_state = _clone_cached_state(self.lm_decoder.get_cached_state(incremental_state)) # recompute cumsum_probs from inter-word transition probabilities # only for those whose prev_output_token is lm_probs: torch.Tensor = self.lm_decoder.get_normalized_probs( @@ -140,8 +139,8 @@ def forward(self, all_children = self.tree.children[nodes, :] # Z[Batch, PossibleChildren] - utils.set_incremental_state(self, incremental_state, 'cumsum_probs', cumsum_probs) - utils.set_incremental_state(self, incremental_state, 'nodes', nodes) + self.set_incremental_state(incremental_state, 'cumsum_probs', cumsum_probs) + self.set_incremental_state(incremental_state, 'nodes', nodes) # Compute probabilities # initialize out_probs [Batch, 1, Vocab] @@ -227,20 +226,15 @@ def forward(self, def reorder_incremental_state(self, incremental_state, new_order): super().reorder_incremental_state(incremental_state, new_order) - cumsum_probs = utils.get_incremental_state( - self, incremental_state, 'cumsum_probs') + cumsum_probs = self.get_incremental_state(incremental_state, 'cumsum_probs') if cumsum_probs is not None: new_cumsum_probs = cumsum_probs.index_select(0, new_order) - utils.set_incremental_state( - self, incremental_state, 'cumsum_probs', new_cumsum_probs, - ) + self.set_incremental_state(incremental_state, 'cumsum_probs', new_cumsum_probs) - nodes = utils.get_incremental_state(self, incremental_state, 'nodes') + nodes = self.get_incremental_state(incremental_state, 'nodes') if nodes is not None: new_nodes = nodes.index_select(0, new_order) - utils.set_incremental_state( - self, incremental_state, 'nodes', new_nodes, - ) + self.set_incremental_state(incremental_state, 'nodes', new_nodes) def get_normalized_probs(self, net_output, log_probs, sample=None): """Get normalized probabilities (or log probs) from a net's output.""" diff --git a/examples/asr_librispeech/cmd.sh b/examples/asr_librispeech/cmd.sh index 9e73f25da..cf1ed0174 100644 --- a/examples/asr_librispeech/cmd.sh +++ b/examples/asr_librispeech/cmd.sh @@ -14,7 +14,7 @@ #export cuda_cmd="run.pl --mem 10G --gpu 1" #export decode_cmd="run.pl --mem 4G" -# JHU setup +# JHU setup (copy queue-freegpu.pl from ESPnet into utils/) export train_cmd="queue.pl --mem 10G" -export cuda_cmd="queue.pl --mem 10G --gpu 1 --config conf/gpu.conf" +export cuda_cmd="queue-freegpu.pl --mem 10G --gpu 1 --config conf/gpu.conf" export decode_cmd="queue.pl --mem 4G" diff --git a/examples/asr_swbd/cmd.sh b/examples/asr_swbd/cmd.sh index b14280b96..e531b4431 100644 --- a/examples/asr_swbd/cmd.sh +++ b/examples/asr_swbd/cmd.sh @@ -14,7 +14,7 @@ #export cuda_cmd="run.pl --mem 4G --gpu 1" #export decode_cmd="run.pl --mem 4G" -# JHU setup +# JHU setup (copy queue-freegpu.pl from ESPnet into utils/) export train_cmd="queue.pl --mem 4G" -export cuda_cmd="queue.pl --mem 4G --gpu 1 --config conf/gpu.conf" +export cuda_cmd="queue-freegpu.pl --mem 8G --gpu 1 --config conf/gpu.conf" export decode_cmd="queue.pl --mem 4G" diff --git a/examples/asr_wsj/cmd.sh b/examples/asr_wsj/cmd.sh index b14280b96..e531b4431 100644 --- a/examples/asr_wsj/cmd.sh +++ b/examples/asr_wsj/cmd.sh @@ -14,7 +14,7 @@ #export cuda_cmd="run.pl --mem 4G --gpu 1" #export decode_cmd="run.pl --mem 4G" -# JHU setup +# JHU 
setup (copy queue-freegpu.pl from ESPnet into utils/) export train_cmd="queue.pl --mem 4G" -export cuda_cmd="queue.pl --mem 4G --gpu 1 --config conf/gpu.conf" +export cuda_cmd="queue-freegpu.pl --mem 8G --gpu 1 --config conf/gpu.conf" export decode_cmd="queue.pl --mem 4G" From b35d2106adaa09366ab6cb9302513131270ecd3a Mon Sep 17 00:00:00 2001 From: freewym Date: Thu, 23 Apr 2020 16:46:01 -0400 Subject: [PATCH 079/119] use EncoderOut for SpeechLSTMEncoder's output; code adaptation/changes according to the commits on Apr 21 --- espresso/models/speech_lstm.py | 62 +++++++++++++------------ espresso/speech_train.py | 24 ++++++++-- espresso/tools/simple_greedy_decoder.py | 4 +- 3 files changed, 55 insertions(+), 35 deletions(-) diff --git a/espresso/models/speech_lstm.py b/espresso/models/speech_lstm.py index e19364be5..815070cad 100644 --- a/espresso/models/speech_lstm.py +++ b/espresso/models/speech_lstm.py @@ -20,6 +20,7 @@ register_model, register_model_architecture, ) +from fairseq.models.fairseq_encoder import EncoderOut from fairseq.models.lstm import ( Embedding, LSTM, @@ -403,22 +404,24 @@ def forward(self, src_tokens, src_lengths: Tensor, **unused): encoder_padding_mask = padding_mask.t() - return { - "encoder_out": (x, src_lengths), - "encoder_padding_mask": (encoder_padding_mask if encoder_padding_mask.any() else None, torch.empty(0)), - } + return EncoderOut( + encoder_out=x, # T x B x C + encoder_padding_mask=encoder_padding_mask if encoder_padding_mask.any() else None, # T x B + encoder_embedding=None, + encoder_states=None, + src_tokens=None, + src_lengths=src_lengths, # B + ) - def reorder_encoder_out(self, encoder_out, new_order): - encoder_out["encoder_out"] = ( - encoder_out["encoder_out"][0].index_select(1, new_order), - encoder_out["encoder_out"][1].index_select(0, new_order), + def reorder_encoder_out(self, encoder_out: EncoderOut, new_order): + return EncoderOut( + encoder_out=encoder_out.encoder_out.index_select(1, new_order), + encoder_padding_mask=encoder_out.encoder_padding_mask.index_select(1, new_order) if encoder_out.encoder_padding_mask is not None else None, + encoder_embedding=None, + encoder_states=None, + src_tokens=None, + src_lengths=encoder_out.src_lengths.index_select(0, new_order), ) - if encoder_out["encoder_padding_mask"][0] is not None: - encoder_out["encoder_padding_mask"] = ( - encoder_out["encoder_padding_mask"][0].index_select(1, new_order), - encoder_out['encoder_padding_mask'][1], - ) - return encoder_out def max_positions(self): """Maximum input length supported by the encoder.""" @@ -466,6 +469,7 @@ def __init__( ) for layer in range(num_layers) ]) + if attn_type is None or attn_type.lower() == "none": self.attention = None elif attn_type.lower() == "bahdanau": @@ -478,12 +482,15 @@ def __init__( ) else: raise ValueError("unrecognized attention type.") + if hidden_size + encoder_output_units != out_embed_dim: self.additional_fc = Linear(hidden_size + encoder_output_units, out_embed_dim) + if adaptive_softmax_cutoff is not None: # setting adaptive_softmax dropout to dropout_out for now but can be redefined - self.adaptive_softmax = AdaptiveSoftmax(num_embeddings, hidden_size, adaptive_softmax_cutoff, - dropout=dropout_out) + self.adaptive_softmax = AdaptiveSoftmax( + num_embeddings, hidden_size, adaptive_softmax_cutoff, dropout=dropout_out + ) elif not self.share_input_output_embed: self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out) @@ -504,7 +511,7 @@ def get_cached_state(self, incremental_state: Optional[Dict[str, 
Dict[str, Optio def forward( self, prev_output_tokens, - encoder_out: Optional[Dict[str, Tuple[Tensor, Tensor]]] = None, + encoder_out: Optional[EncoderOut] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, **kwargs, ): @@ -512,7 +519,7 @@ def forward( Args: prev_output_tokens (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for input feeding/teacher forcing - encoder_out (Tensor, optional): output from the encoder, used for + encoder_out (EncoderOut, optional): output from the encoder, used for encoder-side attention incremental_state (dict): dictionary used for storing state during :ref:`Incremental decoding` @@ -540,7 +547,7 @@ def _forward_with_scheduled_sampling( self, prev_output_tokens, sampling_prob, - encoder_out: Optional[Dict[str, Tuple[Tensor, Tensor]]] = None, + encoder_out: Optional[EncoderOut] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, ): bsz, seqlen = prev_output_tokens.size() @@ -567,7 +574,7 @@ def _forward_with_scheduled_sampling( def extract_features( self, prev_output_tokens, - encoder_out: Dict[str, Tuple[Tensor, Tensor]], + encoder_out: Optional[EncoderOut] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, **unused, ): @@ -579,17 +586,15 @@ def extract_features( - the decoder's features of shape `(batch, tgt_len, embed_dim)` - attention weights of shape `(batch, tgt_len, src_len)` """ + # get outputs from encoder if encoder_out is not None: assert self.attention is not None - encoder_padding_mask, *_ = encoder_out["encoder_padding_mask"] - encoder_out = encoder_out["encoder_out"] - # get outputs from encoder - encoder_outs = encoder_out[0] - srclen = encoder_outs.size(0) + encoder_outs = encoder_out.encoder_out + encoder_padding_mask = encoder_out.encoder_padding_mask else: - encoder_padding_mask = None - encoder_out = None - srclen = 0 + encoder_outs = torch.empty(0) + encoder_padding_mask = torch.empty(0) + srclen = encoder_outs.size(0) if incremental_state is not None and len(incremental_state) > 0: prev_output_tokens = prev_output_tokens[:, -1:] @@ -604,7 +609,6 @@ def extract_features( x = x.transpose(0, 1) # initialize previous states (or get from cache during incremental generation) - if incremental_state is not None and len(incremental_state) > 0: prev_hiddens, prev_cells, input_feed = self.get_cached_state(incremental_state) else: diff --git a/espresso/speech_train.py b/espresso/speech_train.py index aadd3a362..b381e6420 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -17,7 +17,14 @@ import numpy as np import torch -from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils +from fairseq import ( + checkpoint_utils, + distributed_utils, + options, + quantization_utils, + tasks, + utils, +) from fairseq.data import iterators from fairseq.logging import meters, metrics, progress_bar from fairseq.trainer import Trainer @@ -71,9 +78,19 @@ def main(args, init_distributed=False): sum(p.numel() for p in model.parameters() if p.requires_grad), )) + # (optionally) Configure quantization + if args.quantization_config_path is not None: + quantizer = quantization_utils.Quantizer( + config_path=args.quantization_config_path, + max_epoch=args.max_epoch, + max_update=args.max_update, + ) + else: + quantizer = None + # Build trainer if args.model_parallel_size == 1: - trainer = Trainer(args, task, model, criterion) + trainer = Trainer(args, task, model, criterion, quantizer) else: trainer = 
MegatronTrainer(args, task, model, criterion) @@ -163,8 +180,7 @@ def train(args, trainer, task, epoch_itr, max_update=math.inf): default_log_format=('tqdm' if not args.no_progress_bar else 'simple'), ) - # task specific setup per epoch - task.begin_epoch(epoch_itr.epoch, trainer.get_model()) + trainer.begin_epoch(epoch_itr.epoch) if hasattr(trainer.criterion, 'set_epoch'): trainer.criterion.set_epoch(epoch_itr.epoch) diff --git a/espresso/tools/simple_greedy_decoder.py b/espresso/tools/simple_greedy_decoder.py index 84a1f3403..95122ca7f 100644 --- a/espresso/tools/simple_greedy_decoder.py +++ b/espresso/tools/simple_greedy_decoder.py @@ -88,7 +88,7 @@ def _decode(self, sample: Dict[str, Dict[str, Tensor]], bos_token: Optional[int] target = sample["target"] # target can only be None if not for validation assert target is not None or not self.for_validation - max_encoder_output_length = encoder_outs[0]["encoder_out"][0].size(0) + max_encoder_output_length = encoder_outs[0].encoder_out.size(0) # for validation, make the maximum decoding length equal to at least the # length of target, and the length of encoder_out if possible; otherwise # max_len is obtained from max_len_a/b @@ -103,7 +103,7 @@ def _decode(self, sample: Dict[str, Dict[str, Tensor]], bos_token: Optional[int] tokens = src_tokens.new(bsz, max_len + 2).long().fill_(self.pad) tokens[:, 0] = self.eos if bos_token is None else bos_token # lprobs is only used when target is not None (i.e., for validation) - lprobs = encoder_outs[0]["encoder_out"][0].new_full( + lprobs = encoder_outs[0].encoder_out.new_full( (bsz, target.size(1), self.vocab_size), -np.log(self.vocab_size), ) if self.for_validation else None attn = None From cfa899c55df02d425d5f12bf73df8566256a6a12 Mon Sep 17 00:00:00 2001 From: Yiming Wang Date: Sat, 2 May 2020 19:53:55 -0400 Subject: [PATCH 080/119] Hybrid ASR code (E2E LF-MMI and cross-entropy) and WSJ examples (#29) --- espresso/criterions/lf_mmi_loss.py | 102 ++++ .../subsampled_cross_entropy_with_accuracy.py | 112 ++++ espresso/data/__init__.py | 6 + espresso/data/asr_chain_dataset.py | 316 +++++++++++ espresso/data/asr_xent_dataset.py | 517 ++++++++++++++++++ espresso/data/feat_text_dataset.py | 27 +- espresso/dump_posteriors.py | 206 +++++++ espresso/models/speech_lstm.py | 6 +- espresso/models/speech_lstm_encoder_model.py | 229 ++++++++ espresso/models/speech_tdnn.py | 302 ++++++++++ espresso/speech_train.py | 4 + espresso/tasks/speech_recognition.py | 2 +- espresso/tasks/speech_recognition_hybrid.py | 382 +++++++++++++ espresso/tools/.gitignore | 2 + espresso/tools/Makefile | 65 ++- espresso/tools/asr_prep_json.py | 12 + ...ate_initial_state_prior_from_alignments.py | 68 +++ .../tools/generate_log_probs_for_decoding.py | 68 +++ examples/asr_wsj/conf/mfcc_hires.conf | 10 + examples/asr_wsj/local/common_data_prep.sh | 76 +++ examples/asr_wsj/local/score.sh | 1 + examples/asr_wsj/path.sh | 3 +- examples/asr_wsj/run_chain_e2e.sh | 238 ++++++++ examples/asr_wsj/run_xent.sh | 217 ++++++++ fairseq/sequence_generator.py | 4 +- 25 files changed, 2957 insertions(+), 18 deletions(-) create mode 100644 espresso/criterions/lf_mmi_loss.py create mode 100644 espresso/criterions/subsampled_cross_entropy_with_accuracy.py create mode 100644 espresso/data/asr_chain_dataset.py create mode 100644 espresso/data/asr_xent_dataset.py create mode 100755 espresso/dump_posteriors.py create mode 100644 espresso/models/speech_lstm_encoder_model.py create mode 100644 espresso/models/speech_tdnn.py create mode 100644 
espresso/tasks/speech_recognition_hybrid.py create mode 100755 espresso/tools/estimate_initial_state_prior_from_alignments.py create mode 100644 espresso/tools/generate_log_probs_for_decoding.py create mode 100644 examples/asr_wsj/conf/mfcc_hires.conf create mode 100755 examples/asr_wsj/local/common_data_prep.sh create mode 120000 examples/asr_wsj/local/score.sh create mode 100755 examples/asr_wsj/run_chain_e2e.sh create mode 100755 examples/asr_wsj/run_xent.sh diff --git a/espresso/criterions/lf_mmi_loss.py b/espresso/criterions/lf_mmi_loss.py new file mode 100644 index 000000000..e7636a59b --- /dev/null +++ b/espresso/criterions/lf_mmi_loss.py @@ -0,0 +1,102 @@ +# Copyright (c) Yiming Wang, Yiwen Shao +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +from fairseq import utils +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.logging import metrics + + +@register_criterion("lattice_free_mmi") +class LatticeFreeMMICriterion(FairseqCriterion): + + def __init__( + self, task, sentence_avg, denominator_fst_path, + den_leaky_hmm_coefficient, num_leaky_hmm_coefficient, + ): + super().__init__(task) + try: + from pychain.graph import ChainGraph + import simplefst + except ImportError: + raise ImportError("Please install OpenFST and PyChain by `make openfst pychain` after entering espresso/tools") + + self.sentence_avg = sentence_avg + den_fst = simplefst.StdVectorFst.read(denominator_fst_path) + self.den_graph = ChainGraph(den_fst, leaky_mode="transition") + self.den_leaky_hmm_coefficient = den_leaky_hmm_coefficient + self.num_leaky_hmm_coefficient = num_leaky_hmm_coefficient + + @staticmethod + def add_args(parser): + """Add criterion-specific arguments to the parser.""" + # fmt: off + FairseqCriterion.add_args(parser) + parser.add_argument("--denominator-fst-path", type=str, metavar="FILE", + help="path to the denominator fst file") + parser.add_argument("--den-leaky-hmm-coefficient", default=1.0e-05, type=float, metavar="F", + help="leaky-hmm coefficient for the denominator") + parser.add_argument("--num-leaky-hmm-coefficient", default=1.0e-15, type=float, metavar="F", + help="leaky-hmm coefficient for the numerator") + # fmt: on + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. 
+ + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample["net_input"]) + loss, _ = self.compute_loss(net_output, sample, reduce=reduce) + + sample_size = sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + logging_output = { + "loss": loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["nsentences"], + "sample_size": sample_size, + } + return loss, sample_size, logging_output + + def compute_loss(self, net_output, sample, reduce=True): + try: + from pychain.graph import ChainGraphBatch + from pychain.loss import ChainFunction + except ImportError: + raise ImportError("Please install OpenFST and PyChain by `make openfst pychain` after entering espresso/tools") + + den_graphs = ChainGraphBatch(self.den_graph, sample["nsentences"]) + encoder_out = net_output.encoder_out.transpose(0, 1) # T x B x V -> B x T x V + out_lengths = net_output.src_lengths.long() # B + den_objf = ChainFunction.apply(encoder_out, out_lengths, den_graphs, self.den_leaky_hmm_coefficient) + num_objf = ChainFunction.apply(encoder_out, out_lengths, sample["target"], self.num_leaky_hmm_coefficient) + loss = - num_objf + den_objf # negative log-probs + return loss, loss + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + + metrics.log_scalar("loss", loss_sum / sample_size / math.log(2), sample_size, round=7) + if sample_size != ntokens: + metrics.log_scalar("nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=7) + metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg, round=4)) + else: + metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["loss"].avg, round=4)) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/espresso/criterions/subsampled_cross_entropy_with_accuracy.py b/espresso/criterions/subsampled_cross_entropy_with_accuracy.py new file mode 100644 index 000000000..8ee620276 --- /dev/null +++ b/espresso/criterions/subsampled_cross_entropy_with_accuracy.py @@ -0,0 +1,112 @@ +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
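A minimal standalone sketch, with made-up values, of the loss scaling done in reduce_metrics() above: the summed negative log-probability (in nats) is divided by the sample size and by log(2), so the logged "loss" reads as bits per frame.

import math

loss_sum = 1385.0    # summed negative log-prob over the batch, in nats (made-up value)
sample_size = 1000   # number of frames ("ntokens") when --sentence-avg is not set
logged_loss = loss_sum / sample_size / math.log(2)
print(round(logged_loss, 7))  # ~1.998, i.e. roughly 2 bits per frame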
+ +import logging + +import torch +import torch.nn.functional as F + +from fairseq.criterions import register_criterion +from fairseq.criterions.cross_entropy import CrossEntropyCriterion +from fairseq.logging import metrics + + +logger = logging.getLogger(__name__) + + +@register_criterion("subsampled_cross_entropy_with_accuracy") +class SubsampledCrossEntropyWithAccuracyCriterion(CrossEntropyCriterion): + + def __init__(self, task, sentence_avg): + super().__init__(task, sentence_avg) + self.subsampling_factor = None + # indicate whether to transpose the first two dimensions of net_output + # so that it is B x T x V + self.transpose_net_output = getattr(task, "transpose_net_output", True) + self.state_prior_update_interval = getattr(task, "state_prior_update_interval", None) + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample["net_input"]) + loss, num_corr, num_tot, state_post = self.compute_loss(model, net_output, sample, reduce=reduce) + sample_size = sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + logging_output = { + "loss": loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + "num_corr": num_corr, + "num_tot": num_tot, + "state_post": state_post, + } + return loss, sample_size, logging_output + + def compute_loss(self, model, net_output, sample, reduce=True): + if self.subsampling_factor is None: + self.subsampling_factor = int(round(100 / model.output_lengths(100))) + logger.info("subsampling factor for target labels = {}".format(self.subsampling_factor)) + lprobs = model.get_normalized_probs(net_output, log_probs=True) + if self.transpose_net_output: + lprobs = lprobs.transpose(0, 1).contiguous() # T x B x V -> B x T x V + net_output_length = lprobs.size(1) + lprobs = lprobs.view(-1, lprobs.size(-1)) + target = model.get_targets(sample, net_output) + if self.subsampling_factor > 1: + target = target[:, ::self.subsampling_factor] + target = target[:, :net_output_length] # truncate if necessary + right_pad_length = net_output_length - target.size(1) + if right_pad_length > 0: # pad with the right-most labels on the right + target = torch.cat([target, target[:, -1:].expand(-1, right_pad_length)], 1) + target = target.view(-1) + if not model.training: + # hack for dummy batches, assuming lprobs is longer than targets + lprobs = lprobs[:target.size(0)] + loss = F.nll_loss( + lprobs, + target, + ignore_index=self.padding_idx, + reduction="sum" if reduce else "none", + ) + + with torch.no_grad(): + mask = target.ne(self.padding_idx) + num_corr = (lprobs.argmax(1).masked_select(mask) == target.masked_select(mask)).int().sum() + num_tot = mask.int().sum() + + state_post = None + if ( + hasattr(model, "num_updates") and model.training and + self.state_prior_update_interval is not None and + model.num_updates // self.state_prior_update_interval > + (model.num_updates - 1) // self.state_prior_update_interval + ): + frame_indices = torch.nonzero(mask, as_tuple=True)[0] + state_post = lprobs.index_select(0, frame_indices).exp().mean(0).detach() + + return loss, num_corr, num_tot, state_post + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + 
CrossEntropyCriterion.reduce_metrics(logging_outputs) + num_corr = sum(log.get("num_corr", 0) for log in logging_outputs) + num_tot = sum(log.get("num_tot", 0) for log in logging_outputs) + metrics.log_scalar("accuracy", num_corr.float() / num_tot * 100 if num_tot > 0 else 0.0, num_tot, round=3) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + # because new_state_prior is not a scalar + return False diff --git a/espresso/data/__init__.py b/espresso/data/__init__.py index 1c3cae450..1fadbbc76 100644 --- a/espresso/data/__init__.py +++ b/espresso/data/__init__.py @@ -3,8 +3,10 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from .asr_chain_dataset import AsrChainDataset, NumeratorGraphDataset from .asr_dataset import AsrDataset from .asr_dictionary import AsrDictionary +from .asr_xent_dataset import AliScpCachedDataset, AsrXentDataset from .feat_text_dataset import ( AsrTextDataset, FeatScpCachedDataset, @@ -13,10 +15,14 @@ ) __all__ = [ + 'AliScpCachedDataset', + 'AsrChainDataset', 'AsrDataset', 'AsrDictionary', 'AsrTextDataset', + 'AsrXentDataset', 'FeatScpCachedDataset', 'FeatScpDataset', 'FeatScpInMemoryDataset', + 'NumeratorGraphDataset', ] diff --git a/espresso/data/asr_chain_dataset.py b/espresso/data/asr_chain_dataset.py new file mode 100644 index 000000000..2a68ef7ed --- /dev/null +++ b/espresso/data/asr_chain_dataset.py @@ -0,0 +1,316 @@ +# Copyright (c) Yiming Wang, Yiwen Shao +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
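In the subsampled cross-entropy criterion above, frame-level targets are decimated by the model's subsampling factor, truncated to the network's output length, and right-padded with the last label so that labels and output frames line up one-to-one. A self-contained illustration with made-up values (the factor of 4 and the lengths are assumptions chosen to exercise the padding branch, not defaults):

# Illustrative sketch (not part of the patch): target/output alignment as in compute_loss() above.
import torch

target = torch.arange(10).unsqueeze(0)    # B=1, T=10 frame-level state labels
subsampling_factor = 4                    # assumed; the criterion infers it from the model
net_output_length = 4                     # frames produced by the subsampling encoder

target = target[:, ::subsampling_factor]              # -> [[0, 4, 8]]
target = target[:, :net_output_length]                # truncate if necessary
right_pad_length = net_output_length - target.size(1)
if right_pad_length > 0:                              # repeat the right-most label if too short
    target = torch.cat([target, target[:, -1:].expand(-1, right_pad_length)], 1)
assert target.tolist() == [[0, 4, 8, 8]]              # now matches the output frame count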
+ +import logging +import os +import re +from typing import List + +import numpy as np + +import torch + +from fairseq.data import FairseqDataset + +import espresso.tools.utils as speech_utils + +logger = logging.getLogger(__name__) + + +def collate(samples): + try: + from pychain import ChainGraphBatch + except ImportError: + raise ImportError("Please install OpenFST and PyChain by `make openfst pychain` after entering espresso/tools") + + if len(samples) == 0: + return {} + + def merge(key): + if key == "source": + return speech_utils.collate_frames([s[key] for s in samples], 0.0) + elif key == "target": + max_num_transitions = max(s["target"].num_transitions for s in samples) + max_num_states = max(s["target"].num_states for s in samples) + return ChainGraphBatch( + [s["target"] for s in samples], + max_num_transitions=max_num_transitions, + max_num_states=max_num_states, + ) + else: + raise ValueError("Invalid key.") + + id = torch.LongTensor([s["id"] for s in samples]) + src_frames = merge("source") + # sort by descending source length + src_lengths = torch.IntTensor([s["source"].size(0) for s in samples]) + src_lengths, sort_order = src_lengths.sort(descending=True) + id = id.index_select(0, sort_order) + utt_id = [samples[i]["utt_id"] for i in sort_order.numpy()] + src_frames = src_frames.index_select(0, sort_order) + ntokens = sum(s["source"].size(0) for s in samples) + + target = None + if samples[0].get("target", None) is not None: + target = merge("target") + target.reorder(sort_order) + + text = None + if samples[0].get("text", None) is not None: + text = [samples[i]["text"] for i in sort_order.numpy()] + + batch = { + "id": id, + "utt_id": utt_id, + "nsentences": len(samples), + "ntokens": ntokens, + "net_input": { + "src_tokens": src_frames, + "src_lengths": src_lengths, + }, + "target": target, + "text": text, + } + return batch + + +class NumeratorGraphDataset(FairseqDataset): + """ + A dataset of numerator graphs for LF-MMI. It loads all graphs into memory at + once as its relatively small. 
+ """ + + def __init__(self, utt_ids: List[str], rxfiles: List[str]): + super().__init__() + self.read_fsts(utt_ids, rxfiles) + + def read_fsts(self, utt_ids: List[str], rxfiles: List[str]): + try: + from pychain import ChainGraph + import simplefst + except ImportError: + raise ImportError("Please install OpenFST and PyChain by `make openfst pychain` after entering espresso/tools") + + self.utt_ids = [] + self.rxfiles = [] + self.size = 0 # number of utterances + self.sizes = [] # num of states in each fst + self.numerator_graphs = [] + for i, rxfile in enumerate(rxfiles): + file_path, offset = self._parse_rxfile(rxfile) + fst = simplefst.StdVectorFst.read_ark(file_path, offset) + graph = ChainGraph(fst, leaky_mode="uniform") + if not graph.is_empty: # skip empty graphs + self.utt_ids.append(utt_ids[i]) + self.rxfiles.append(rxfile) + self.size += 1 + self.sizes.append(fst.num_states()) + self.numerator_graphs.append(graph) + self.sizes = np.array(self.sizes, dtype=np.int32) + + def _parse_rxfile(self, rxfile): + # separate offset from filename + m = re.match(r"(\S+):([0-9]+)", rxfile) + assert m is not None, "Illegal rxfile: {}".format(rxfile) + return m.group(1), int(m.group(2)) + + def check_index(self, i): + if i < 0 or i >= self.size: + raise IndexError("index out of range") + + def filter_and_reorder(self, indices): + assert isinstance(indices, (list, np.ndarray)) + indices = np.array(indices) + assert all(indices < len(self.utt_ids)) and all(indices >= 0) + assert len(np.unique(indices)) == len(indices), \ + "Duplicate elements in indices." + self.utt_ids = [self.utt_ids[i] for i in indices] + self.rxfiles = [self.rxfiles[i] for i in indices] + self.numerator_graphs = [self.numerator_graphs[i] for i in indices] + self.sizes = self.sizes[indices] + self.size = len(self.utt_ids) + + def __getitem__(self, i): + self.check_index(i) + return self.numerator_graphs[i] + + def __len__(self): + return self.size + + @staticmethod + def exists(path): + return os.path.exists(path) + + +class AsrChainDataset(FairseqDataset): + """ + A pair of torch.utils.data.Datasets. + + Args: + src (torch.utils.data.Dataset): source dataset to wrap + src_sizes (List[int]): source sentence lengths + tgt (espresso.data.NumeratorGraphDataset, optional): target numerator graph dataset to wrap + tgt_sizes (List[int], optional): target sizes (num of states in the numerator graph) + text (torch.utils.data.Dataset, optional): text dataset to wrap + max_source_positions (int, optional): max number of frames in the + source (default: 1024). 
+ max_target_positions (int, optional): max number of tokens in the target + sentence (default: 1024) + shuffle (bool, optional): shuffle dataset elements before batching + (default: True) + """ + + def __init__( + self, src, src_sizes, tgt=None, tgt_sizes=None, text=None, + max_source_positions=1024, max_target_positions=1024, shuffle=True, + ): + self.src = src + self.tgt = tgt + self.src_sizes = np.array(src_sizes) + self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None + self.text = text + self.max_source_positions = max_source_positions + self.max_target_positions = max_target_positions + self.shuffle = shuffle + self.epoch = 1 + num_before_matching = len(self.src.utt_ids) + if self.tgt is not None: + self._match_src_tgt() + if self.text is not None: + changed = self._match_src_text() + if self.tgt is not None and changed: + self._match_src_tgt() + num_after_matching = len(self.src.utt_ids) + num_removed = num_before_matching - num_after_matching + if num_removed > 0: + logger.warning( + "Removed {} examples due to empty numerator graphs or missing entries, " + "{} remaining".format(num_removed, num_after_matching) + ) + + def _match_src_tgt(self): + """Makes utterances in src and tgt the same order in terms of + their utt_ids. Removes those that are only present in one of them.""" + assert self.tgt is not None + if self.src.utt_ids == self.tgt.utt_ids: + return + tgt_utt_ids_set = set(self.tgt.utt_ids) + src_indices = [ + i for i, id in enumerate(self.src.utt_ids) if id in tgt_utt_ids_set + ] + self.src.filter_and_reorder(src_indices) + self.src_sizes = np.array(self.src.sizes) + try: + tgt_indices = list(map(self.tgt.utt_ids.index, self.src.utt_ids)) + except ValueError: + raise ValueError( + "Unable to find some utt_id(s) in tgt. which is unlikely to happen. " + "Something must be wrong." + ) + self.tgt.filter_and_reorder(tgt_indices) + self.tgt_sizes = np.array(self.tgt.sizes) + assert self.src.utt_ids == self.tgt.utt_ids + + def _match_src_text(self): + """Makes utterances in src and text the same order in terms of + their utt_ids. Removes those that are only present in one of them.""" + assert self.text is not None + if self.src.utt_ids == self.text.utt_ids: + return False + text_utt_ids_set = set(self.text.utt_ids) + src_indices = [ + i for i, id in enumerate(self.src.utt_ids) if id in text_utt_ids_set + ] + self.src.filter_and_reorder(src_indices) + self.src_sizes = np.array(self.src.sizes) + try: + text_indices = list(map(self.text.utt_ids.index, self.src.utt_ids)) + except ValueError: + raise ValueError( + "Unable to find some utt_id(s) in text. which is unlikely to happen. " + "Something must be wrong." + ) + self.text.filter_and_reorder(text_indices) + assert self.src.utt_ids == self.text.utt_ids + return True + + def __getitem__(self, index): + tgt_item = self.tgt[index] if self.tgt is not None else None + text_item = self.text[index][1] if self.text is not None else None + src_item = self.src[index] + example = { + "id": index, + "utt_id": self.src.utt_ids[index], + "source": src_item, + "target": tgt_item, + "text": text_item, + } + return example + + def __len__(self): + return len(self.src) + + def collater(self, samples): + """Merge a list of samples to form a mini-batch. 
+ + Args: + samples (List[dict]): samples to collate + + Returns: + dict: a mini-batch with the following keys: + + - `id` (LongTensor): example IDs in the original input order + - `utt_id` (List[str]): list of utterance ids + - `nsentences` (int): batch size + - `ntokens` (int): total number of tokens in the batch + - `net_input` (dict): the input to the Model, containing keys: + + - `src_tokens` (FloatTensor): a padded 3D Tensor of features in + the source of shape `(bsz, src_len, feat_dim)`. + - `src_lengths` (IntTensor): 1D Tensor of the unpadded + lengths of each source sequence of shape `(bsz)` + + - `target` (ChainGraphBatch): an instance representing a batch of + numerator graphs + - `text` (List[str]): list of original text + """ + return collate(samples) + + def num_tokens(self, index): + """Return the number of frames in a sample. This value is used to + enforce ``--max-tokens`` during batching.""" + return self.src_sizes[index] + + def size(self, index): + """Return an example's size as a float or tuple. This value is used when + filtering a dataset with ``--max-positions``.""" + return (self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0) + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + if self.shuffle: + indices = np.random.permutation(len(self)) + else: + indices = np.arange(len(self)) + if self.tgt_sizes is not None: + indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")] + return indices[np.argsort(self.src_sizes[indices], kind="mergesort")] + + @property + def supports_prefetch(self): + return getattr(self.src, "supports_prefetch", False) + + def prefetch(self, indices): + """Only prefetch src.""" + self.src.prefetch(indices) + + def set_epoch(self, epoch): + super().set_epoch(epoch) + self.epoch = epoch + if hasattr(self.src, "set_epoch"): + self.src.set_epoch(epoch) + if self.tgt is not None and hasattr(self.tgt, "set_epoch"): + self.tgt.set_epoch(epoch) diff --git a/espresso/data/asr_xent_dataset.py b/espresso/data/asr_xent_dataset.py new file mode 100644 index 000000000..60d573daf --- /dev/null +++ b/espresso/data/asr_xent_dataset.py @@ -0,0 +1,517 @@ +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
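AsrChainDataset.ordered_indices() above relies on the stability of mergesort: sorting first by target size and then by source size leaves the indices ordered by source length with target length breaking ties, which keeps mini-batches roughly length-homogeneous. A small self-contained check (array contents are arbitrary example values):

# Illustrative sketch (not part of the patch): the two-pass stable sort in ordered_indices().
import numpy as np

src_sizes = np.array([5, 5, 3, 8])   # e.g. frames per utterance
tgt_sizes = np.array([2, 1, 4, 3])   # e.g. numerator-graph states per utterance

indices = np.arange(len(src_sizes))
indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]  # secondary key
indices = indices[np.argsort(src_sizes[indices], kind="mergesort")]  # primary key
print(indices)  # [2 1 0 3]: src lengths 3, 5, 5, 8; the tie at 5 follows target order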
+ +import logging +import os +from typing import List, Optional + +import numpy as np + +import torch +import torch.nn.functional as F + +from fairseq.data import data_utils, FairseqDataset + +import espresso.tools.utils as speech_utils + +try: + import kaldi_io +except ImportError: + raise ImportError("Please install kaldi_io with: pip install kaldi_io") + + +logger = logging.getLogger(__name__) + + +def collate( + samples, pad_idx, chunk_width, chunk_left_context, chunk_right_context, label_delay, + seed, epoch, random_chunking=True, +): + if len(samples) == 0: + return {} + + def merge(key): + if key == "source": + return speech_utils.collate_frames([s[key] for s in samples], 0.0) + elif key == "target": + return data_utils.collate_tokens( + [s[key] for s in samples], + pad_idx=pad_idx, eos_idx=None, + left_pad=False, move_eos_to_beginning=False, + ) + else: + raise ValueError("Invalid key.") + + def chunking(src_item, tgt_item, tgt_start): + # make a src chunk in the range [begin_src, end_src) + begin_src = max(0, tgt_start + label_delay - chunk_left_context) + # ok if end_src past the end of utterance + end_src = tgt_start + label_delay + chunk_width + chunk_right_context + # replication pad if necessary + left_pad = max(0, chunk_left_context - tgt_start - label_delay) + right_pad = max(0, end_src - src_item.size(0)) + src_item = src_item[begin_src: end_src] + if left_pad > 0 or right_pad > 0: + src_item = F.pad( + src_item.t().unsqueeze(0), (left_pad, right_pad), mode="replicate", + ).squeeze(0).t() + + if tgt_item is not None: + # make a tgt chunk in the range [begin_tgt, end_tgt) + begin_tgt = tgt_start + end_tgt = tgt_start + chunk_width # ok if past the end of utterance + # replication pad if necessary + right_pad = max(0, end_tgt - tgt_item.size(0)) + tgt_item = tgt_item[begin_tgt: end_tgt] + if right_pad > 0: + tgt_item = torch.cat( + (tgt_item, tgt_item.new_full((right_pad,), pad_idx)), 0 + ) + return src_item, tgt_item + + if chunk_width is None or random_chunking: + if chunk_width is not None: # usually for chunk-wise train data + # no need to sort as all chunks have exactly the same length + for s in samples: + with data_utils.numpy_seed(seed, epoch, s["id"]): + # generate a chunk by sampling the index of its first label + f = np.random.randint(s["source"].size(0) - chunk_width + 1) + s["source"], s["target"] = chunking(s["source"], s["target"], f) + elif label_delay != 0: # shift source according to label_delay + if label_delay > 0: + left_pad, right_pad = 0, label_delay + else: + left_pad, right_pad = -label_delay, 0 + for s in samples: + src_item = s["source"] + src_item = F.pad( + src_item.t().unsqueeze(0), (left_pad, right_pad), mode="replicate", + ).squeeze(0).t() + if label_delay > 0: + s["source"] = src_item[label_delay:] + else: + s["source"] = src_item[: label_delay] + + src_lengths = torch.IntTensor([s["source"].size(0) for s in samples]) + id = torch.LongTensor([s["id"] for s in samples]) + utt_id = [s["utt_id"] for s in samples] + src_frames = merge("source") + + target = None + if samples[0].get("target", None) is not None: + target = merge("target") + ntokens = sum(len(s["target"]) for s in samples) + else: + ntokens = sum(s["source"].size(0) for s in samples) + + text = None + if samples[0].get("text", None) is not None: + text = [s["text"] for s in samples] + + if chunk_width is None: # for whole utterances (i.e., no chunking) + # sort by descending source length + src_lengths, sort_order = src_lengths.sort(descending=True) + id = id.index_select(0, 
sort_order) + utt_id = [utt_id[i] for i in sort_order.numpy()] + src_frames = src_frames.index_select(0, sort_order) + if target is not None: + target = target.index_select(0, sort_order) + if text is not None: + text = [text[i] for i in sort_order.numpy()] + + batch = { + "id": id, + "utt_id": utt_id, + "nsentences": len(samples), + "ntokens": ntokens, + "net_input": { + "src_tokens": src_frames, + "src_lengths": src_lengths, + }, + "target": target, + "text": text, + } + return batch + else: # sequential chunking, usually for chunk-wise test data + src_lengths = torch.IntTensor([s["source"].size(0) for s in samples]) + id = torch.LongTensor([s["id"] for s in samples]) + utt_id = [s["utt_id"] for s in samples] + ori_source = [s["source"] for s in samples] + ori_target = [s["target"] for s in samples] + text = None + if samples[0].get("text", None) is not None: + text = [s["text"] for s in samples] + max_length = max(src.size(0) for src in ori_source) + num_chunks = (max_length + chunk_width - 1) // chunk_width + batches = [] + for k in range(num_chunks): + f = k * chunk_width + for i, s in enumerate(samples): + if f < src_lengths[i].item(): + s["source"], s["target"] = chunking(ori_source[i], ori_target[i], f) + else: + s["source"] = ori_source[i].new_zeros( + chunk_width + chunk_left_context + chunk_right_context, ori_source[i].size(1) + ) + s["target"] = ori_target[i].new_full((chunk_width,), pad_idx) \ + if ori_target[i] is not None else None + src_frames = merge("source") + src_chunk_lengths = torch.IntTensor([s["source"].size(0) for s in samples]) + + target = None + if samples[0].get("target", None) is not None: + target = merge("target") + ntokens = sum(s["target"].ne(pad_idx).int().sum().item() for s in samples) + else: + ntokens = sum(s["source"].size(0) for s in samples) + + batch = { + "id": id, + "utt_id": utt_id, + "nsentences": len(samples) if k == 0 else 0, + "ntokens": ntokens, + "net_input": { + "src_tokens": src_frames, + "src_lengths": src_chunk_lengths, + }, + "target": target, + "text": text, + } + batches.append(batch) + return batches + + +class AliScpCachedDataset(torch.utils.data.Dataset): + """ + A dataset for alignments prepared in Kaldi scp format (e.g., ali.scp). + This class loads a batch of feature matrices (specified as *cache_size*) + every time an entry is inquired. The inquire order should be known in advance. + It balances the I/O efficiency and memory usage. 
+ """ + + def __init__( + self, utt_ids: List[str], rxfiles: List[str], utt2num_frames: Optional[List[int]] = None, + ordered_prefetch=False, cache_size=327680, + ): + super().__init__() + assert len(utt_ids) == len(rxfiles) + self.dtype = np.int16 + self.utt_ids = utt_ids + self.rxfiles = rxfiles + self.size = len(utt_ids) # number of utterances + self.sizes = [] # length of each utterance + if utt2num_frames is not None and len(utt2num_frames) > 0: + assert len(utt2num_frames) == self.size + self.sizes = utt2num_frames + + if len(self.sizes) == 0: + for rxfile in self.rxfiles: + try: + ali = kaldi_io.read_vec_int(rxfile) + except Exception: + raise Exception("failed to read int vector {}.".format(rxfile)) + assert ali is not None and isinstance(ali, np.ndarray) + self.sizes.append(ali.shape[0]) + + assert len(self.sizes) == self.size + self.sizes = np.array(self.sizes, dtype=np.int32) + + self.cache = None + self.cache_index = {} + self.cache_size = cache_size # in terms of number of examples + self.start_pos_for_next_cache = 0 + self.ordered_indices = list(range(self.size)) + # set to True ONLY if examples are queried in the same order as + # self.ordered_indices, and doing this will speed up search of the + # queried index + self.ordered_prefetch = ordered_prefetch + + @property + def supports_prefetch(self): + return True + + def prefetch(self, indices): + """Sets self.ordered_indices. If being called, the caller is supposed to + query examples in the same order as self.ordered_indices. + self.ordered_prefetch can be set to True in this case. Note: the purpose + of this function is different from what it is supposed to do in the + fairseq framework.""" + assert isinstance(indices, (list, np.ndarray)) + assert self.size >= len(indices) + self.ordered_indices = indices.copy() + self.start_pos_for_next_cache = 0 + + def check_index(self, i): + if i < 0 or i >= self.size: + raise IndexError("index out of range") + + def filter_and_reorder(self, indices): + assert isinstance(indices, (list, np.ndarray)) + indices = np.array(indices) + assert all(indices < len(self.utt_ids)) and all(indices >= 0) + assert len(np.unique(indices)) == len(indices), \ + "Duplicate elements in indices." + self.utt_ids = [self.utt_ids[i] for i in indices] + self.rxfiles = [self.rxfiles[i] for i in indices] + self.sizes = self.sizes[indices] + self.size = len(self.utt_ids) + self.ordered_indices = list(range(self.size)) + + def __getitem__(self, i): + self.check_index(i) + if i not in self.cache_index: + assert self.start_pos_for_next_cache < \ + len(self.ordered_indices), \ + "Position for next cache starting beyond the end of ordered_indices." + try: + pos_start = self.ordered_indices.index( + i, self.start_pos_for_next_cache, + ) + except ValueError: + raise ValueError( + "index {} not found in self.ordered_indices. 
Set " + "self.ordered_prefetch to False, and/or call self.prefetch() " + "with the full list of indices, and then try again.".format(i) + ) + pos_end = min( + pos_start + self.cache_size, len(self.ordered_indices), + ) + self.start_pos_for_next_cache = pos_end \ + if self.ordered_prefetch else 0 + total_size = 0 + for idx in self.ordered_indices[pos_start: pos_end]: + total_size += self.sizes[idx] + self.cache = np.empty(total_size, dtype=self.dtype) + ptx = 0 + self.cache_index.clear() + for idx in self.ordered_indices[pos_start: pos_end]: + self.cache_index[idx] = ptx + length = self.sizes[idx] + dst = self.cache[ptx: ptx + length] + np.copyto(dst, kaldi_io.read_vec_int(self.rxfiles[idx])) + ptx += length + + ptx = self.cache_index[i] + a = self.cache[ptx: ptx + self.sizes[i]].copy() + return torch.from_numpy(a).long() + + def __len__(self): + return self.size + + @staticmethod + def exists(path): + return os.path.exists(path) + + +class AsrXentDataset(FairseqDataset): + """ + A pair of torch.utils.data.Datasets. + + Args: + src (torch.utils.data.Dataset): source dataset to wrap + src_sizes (List[int]): source sentence lengths + tgt (espresso.data.AliScpCachedDataset, optional): target alignment dataset to wrap + tgt_sizes (List[int], optional): target sizes (num of states in the numerator graph) + tgt_vocab_size (int, optional): used for setting padding index + text (torch.utils.data.Dataset, optional): text dataset to wrap + max_source_positions (int, optional): max number of frames in the + source (default: 1024). + max_target_positions (int, optional): max number of tokens in the target + sentence (default: 1024) + shuffle (bool, optional): shuffle dataset elements before batching + (default: True) + seed (int, optional): random seed for generating a chunk from an utterance + chunk_width (int, optional): chunk width for chunk-wise training + chunk_left_context (int, optional): number of frames appended to the left of a chunk + chunk_right_context (int, optional): number of frames appended to the right of a chunk + label_delay (int, optional): offset of the alignments as prediction labels. Can be + useful in archs such as asymmetric convolution, unidirectional LSTM, etc + random_chunking (bool, optional): wether do random chunking from utterance, or sequntially + obtain chunks within each utterance. 
True for train and False for valid/test data + """ + + def __init__( + self, src, src_sizes, tgt: Optional[AliScpCachedDataset] = None, tgt_sizes=None, text=None, + max_source_positions=1024, max_target_positions=1024, shuffle=True, + seed=1, chunk_width=None, chunk_left_context=None, chunk_right_context=None, + label_delay=0, random_chunking=True, + ): + self.src = src + self.tgt = tgt + self.src_sizes = np.array(src_sizes) + self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None + self.text = text + self.max_source_positions = max_source_positions + self.max_target_positions = max_target_positions + self.shuffle = shuffle + self.seed = seed + self.epoch = 1 + assert chunk_width is None or chunk_width > 0 + self.chunk_width = chunk_width + assert chunk_left_context >= 0 and chunk_right_context >= 0 + self.chunk_left_context = chunk_left_context + self.chunk_right_context = chunk_right_context + assert (label_delay < 0 and -label_delay <= chunk_right_context) or \ + (label_delay >= 0 and (chunk_width is None or label_delay < chunk_width)) + self.label_delay = label_delay + self.random_chunking = random_chunking + if self.tgt is not None: + self._match_src_tgt() + if self.text is not None: + changed = self._match_src_text() + if self.tgt is not None and changed: + self._match_src_tgt() + + if chunk_width is not None: + # remove those whose lengths are shorter than chunk_size + indices = np.flatnonzero(self.src.sizes >= chunk_width) + if len(indices) < self.src.size: + logger.warning( + "Removing {} examples whose lengths are shorter than chunk_size={}".format( + self.src.size - len(indices), chunk_width + ) + ) + self.src.filter_and_reorder(indices) + if self.tgt is not None: + self.tgt.filter_and_reorder(indices) + if self.text is not None: + self.text.filter_and_reorder(indices) + logger.warning("Done removal. {} examples remaining".format(len(indices))) + + def _match_src_tgt(self): + """Makes utterances in src and tgt the same order in terms of + their utt_ids. Removes those that are only present in one of them.""" + assert self.tgt is not None + if self.src.utt_ids == self.tgt.utt_ids: + assert np.all(self.src.sizes == self.tgt.sizes), "frame and alignment lengths mismatch" + return + tgt_utt_ids_set = set(self.tgt.utt_ids) + src_indices = [ + i for i, id in enumerate(self.src.utt_ids) if id in tgt_utt_ids_set + ] + self.src.filter_and_reorder(src_indices) + self.src_sizes = np.array(self.src.sizes) + try: + tgt_indices = list(map(self.tgt.utt_ids.index, self.src.utt_ids)) + except ValueError: + raise ValueError( + "Unable to find some utt_id(s) in tgt. which is unlikely to happen. " + "Something must be wrong." + ) + self.tgt.filter_and_reorder(tgt_indices) + self.tgt_sizes = np.array(self.tgt.sizes) + assert self.src.utt_ids == self.tgt.utt_ids + assert np.all(self.src.sizes == self.tgt.sizes), "frame and alignment lengths mismatch" + + def _match_src_text(self): + """Makes utterances in src and text the same order in terms of + their utt_ids. Removes those that are only present in one of them.""" + assert self.text is not None + if self.src.utt_ids == self.text.utt_ids: + return False + text_utt_ids_set = set(self.text.utt_ids) + src_indices = [ + i for i, id in enumerate(self.src.utt_ids) if id in text_utt_ids_set + ] + self.src.filter_and_reorder(src_indices) + self.src_sizes = np.array(self.src.sizes) + try: + text_indices = list(map(self.text.utt_ids.index, self.src.utt_ids)) + except ValueError: + raise ValueError( + "Unable to find some utt_id(s) in text. 
which is unlikely to happen. " + "Something must be wrong." + ) + self.text.filter_and_reorder(text_indices) + assert self.src.utt_ids == self.text.utt_ids + return True + + def __getitem__(self, index): + tgt_item = self.tgt[index] if self.tgt is not None else None + text_item = self.text[index][1] if self.text is not None else None + src_item = self.src[index] + example = { + "id": index, + "utt_id": self.src.utt_ids[index], + "source": src_item, + "target": tgt_item, + "text": text_item, + } + return example + + def __len__(self): + return len(self.src) + + def collater(self, samples): + """Merge a list of samples to form a mini-batch. + + Args: + samples (List[dict]): samples to collate + + Returns: + dict: a mini-batch with the following keys: + + - `id` (LongTensor): example IDs in the original input order + - `utt_id` (List[str]): list of utterance ids + - `nsentences` (int): batch size + - `ntokens` (int): total number of tokens in the batch + - `net_input` (dict): the input to the Model, containing keys: + + - `src_tokens` (FloatTensor): a padded 3D Tensor of features in + the source of shape `(bsz, src_len, feat_dim)`. + - `src_lengths` (IntTensor): 1D Tensor of the unpadded + lengths of each source sequence of shape `(bsz)` + + - `target` (LongTensor): a padded 2D Tensor of indices in the + target alignments of shape `(bsz, tgt_len)` + - `text` (List[str]): list of original text + """ + # pad_idx=-100 matches the default in criterions + return collate( + samples, pad_idx=-100, chunk_width=self.chunk_width, + chunk_left_context=self.chunk_left_context, chunk_right_context=self.chunk_right_context, + label_delay=self.label_delay, seed=self.seed, epoch=self.epoch, + random_chunking=self.random_chunking, + ) + + def num_tokens(self, index): + """Return the number of frames in a sample. This value is used to + enforce ``--max-tokens`` during batching.""" + if self.chunk_width is None: + return self.src_sizes[index] + return self.chunk_width + self.chunk_left_context + self.chunk_right_context + + def size(self, index): + """Return an example's size as a float or tuple. This value is used when + filtering a dataset with ``--max-positions``.""" + return (self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0) + + def ordered_indices(self): + """Return an ordered list of indices. 
Batches will be constructed based + on this order.""" + if self.shuffle: + indices = np.random.permutation(len(self)) + else: + indices = np.arange(len(self)) + if self.tgt_sizes is not None: + indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")] + return indices[np.argsort(self.src_sizes[indices], kind="mergesort")] + + @property + def supports_prefetch(self): + return getattr(self.src, "supports_prefetch", False) + + def prefetch(self, indices): + self.src.prefetch(indices) + if self.tgt is not None: + self.tgt.prefetch(indices) + + def set_epoch(self, epoch): + super().set_epoch(epoch) + self.epoch = epoch + if hasattr(self.src, "set_epoch"): + self.src.set_epoch(epoch) + if self.tgt is not None and hasattr(self.tgt, "set_epoch"): + self.tgt.set_epoch(epoch) diff --git a/espresso/data/feat_text_dataset.py b/espresso/data/feat_text_dataset.py index eb70e863f..c4c49383b 100644 --- a/espresso/data/feat_text_dataset.py +++ b/espresso/data/feat_text_dataset.py @@ -11,6 +11,7 @@ import torch from fairseq.data import data_utils +from fairseq.tokenizer import tokenize_line from espresso.tools.specaug_interpolate import specaug @@ -225,30 +226,33 @@ class AsrTextDataset(torch.utils.data.Dataset): Original lines are also kept in memory. Each line of the text file is in the format of 'utt_id tokenized_text'.""" - def __init__(self, utt_ids: List[str], token_text: List[str], dictionary, append_eos=True): + def __init__(self, utt_ids: List[str], token_text: List[str], dictionary=None, append_eos=True): super().__init__() self.dtype = np.float self.append_eos = append_eos self.read_text(utt_ids, token_text, dictionary) - def read_text(self, utt_ids: List[str], token_text: List[str], dictionary): + def read_text(self, utt_ids: List[str], token_text: List[str], dictionary=None): assert len(utt_ids) == len(token_text) self.utt_ids = utt_ids self.tokens_list = token_text self.tensor_list = [] self.size = len(self.utt_ids) # number of utterances self.sizes = [] - for tokens in self.tokens_list: - tensor = dictionary.encode_line( - tokens, add_if_not_exist=False, append_eos=self.append_eos, - ).long() - self.tensor_list.append(tensor) - self.sizes.append(len(self.tensor_list[-1])) + if dictionary is not None: + for tokens in self.tokens_list: + tensor = dictionary.encode_line( + tokens, add_if_not_exist=False, append_eos=self.append_eos, + ).long() + self.tensor_list.append(tensor) + self.sizes.append(len(self.tensor_list[-1])) + else: + self.sizes = [len(tokenize_line(tokens)) for tokens in self.tokens_list] self.sizes = np.array(self.sizes, dtype=np.int32) assert len(self.utt_ids) == len(self.tokens_list) and \ - len(self.utt_ids) == len(self.tensor_list) and \ + (dictionary is None or len(self.utt_ids) == len(self.tensor_list)) and \ len(self.utt_ids) == len(self.sizes) def check_index(self, i): @@ -263,13 +267,14 @@ def filter_and_reorder(self, indices): 'Duplicate elements in indices.' 
self.utt_ids = [self.utt_ids[i] for i in indices] self.tokens_list = [self.tokens_list[i] for i in indices] - self.tensor_list = [self.tensor_list[i] for i in indices] + if len(self.tensor_list) > 0: + self.tensor_list = [self.tensor_list[i] for i in indices] self.sizes = self.sizes[indices] self.size = len(self.utt_ids) def __getitem__(self, i): self.check_index(i) - return self.tensor_list[i], self.tokens_list[i] + return self.tensor_list[i] if len(self.tensor_list) > 0 else None, self.tokens_list[i] def __len__(self): return self.size diff --git a/espresso/dump_posteriors.py b/espresso/dump_posteriors.py new file mode 100755 index 000000000..a3ee532ac --- /dev/null +++ b/espresso/dump_posteriors.py @@ -0,0 +1,206 @@ +#!/usr/bin/env python3 +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Dump frame-level posteriors (intepreted as log probabilities) with a trained model +for decoding with Kaldi. +""" + +import logging +import sys + +import torch + +from fairseq import checkpoint_utils, options, tasks, utils +from fairseq.logging import progress_bar +from fairseq.logging.meters import StopwatchMeter + +try: + import kaldi_io +except ImportError: + raise ImportError("Please install kaldi_io with: pip install kaldi_io") + + +def main(args): + assert args.path is not None, "--path required for decoding!" + return _main(args, sys.stderr) + + +def _main(args, output_file): + logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO, + stream=output_file, + ) + logger = logging.getLogger("espresso.dump_posteriors") + + print_options_meaning_changes(args, logger) + + utils.import_user_module(args) + + if args.max_tokens is None and args.max_sentences is None: + args.max_tokens = 12000 + logger.info(args) + + use_cuda = torch.cuda.is_available() and not args.cpu + + # Load dataset split + task = tasks.setup_task(args) + task.load_dataset(args.gen_subset) + + # Load ensemble + logger.info("loading model(s) from {}".format(args.path)) + models, _model_args = checkpoint_utils.load_model_ensemble( + utils.split_paths(args.path), + arg_overrides=eval(args.model_overrides), + task=task, + ) + + # Load state prior for cross-entropy trained systems decoding + if args.state_prior_file is not None: + prior = torch.from_numpy(kaldi_io.read_vec_flt(args.state_prior_file)) + else: + prior = [] + + # Optimize ensemble for generation + for model in models: + model.make_generation_fast_() + if args.fp16: + model.half() + if use_cuda: + model.cuda() + if isinstance(prior, list) and getattr(model, "state_prior", None) is not None: + prior.append(model.state_prior.unsqueeze(0)) + + if isinstance(prior, list) and len(prior) > 0: + prior = torch.cat(prior, 0).mean(0) # average priors across models + prior = prior / prior.sum() # re-normalize + elif isinstance(prior, list): + prior = None + + if prior is not None: + if args.fp16: + prior = prior.half() + if use_cuda: + prior = prior.cuda() + log_prior = prior.log() + else: + log_prior = None + + # Load dataset (possibly sharded) + itr = task.get_batch_iterator( + dataset=task.dataset(args.gen_subset), + max_tokens=args.max_tokens, + max_sentences=args.max_sentences, + max_positions=utils.resolve_max_positions( + task.max_positions(), + *[model.max_positions() if hasattr(model, "encoder") + else (None, model.max_positions()) for model in models] + ), + 
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=args.required_batch_size_multiple, + num_shards=args.num_shards, + shard_id=args.shard_id, + num_workers=args.num_workers, + ).next_epoch_itr(shuffle=False) + progress = progress_bar.progress_bar( + itr, + log_format=args.log_format, + log_interval=args.log_interval, + default_log_format=("tqdm" if not args.no_progress_bar else "none"), + ) + + # Initialize generator + gen_timer = StopwatchMeter() + generator = task.build_generator(models, args) + + # Generate and dump + num_sentences = 0 + chunk_width = getattr(task, "chunk_width", None) + lprobs_wspecifier = "ark:| copy-matrix ark:- ark:-" + with kaldi_io.open_or_fd(lprobs_wspecifier, "wb") as f: + if chunk_width is None: # normal dumping (i.e., no chunking) + for sample in progress: + sample = utils.move_to_cuda(sample) if use_cuda else sample + if "net_input" not in sample: + continue + + gen_timer.start() + lprobs, padding_mask = task.inference_step(generator, models, sample) + if log_prior is not None: + assert lprobs.size(-1) == log_prior.size(0) + lprobs = lprobs - log_prior + out_lengths = (~padding_mask).long().sum(dim=1).cpu() if padding_mask is not None else None + num_processed_frames = sample["ntokens"] + gen_timer.stop(num_processed_frames) + num_sentences += sample["nsentences"] + + if out_lengths is not None: + for i in range(sample["nsentences"]): + length = out_lengths[i] + kaldi_io.write_mat(f, lprobs[i, :length, :].cpu().numpy(), key=sample["utt_id"][i]) + else: + for i in range(sample["nsentences"]): + kaldi_io.write_mat(f, lprobs[i, :, :].cpu().numpy(), key=sample["utt_id"][i]) + else: # dumping chunks within the same utterance from left to right + for sample in progress: # sample is actually a list of batches + sample = utils.move_to_cuda(sample) if use_cuda else sample + utt_id = sample[0]["utt_id"] + id = sample[0]["id"] + whole_lprobs = None + for i, chunk_sample in enumerate(sample): + if "net_input" not in chunk_sample: + continue + + assert chunk_sample["utt_id"] == utt_id and (chunk_sample["id"] == id).all() + gen_timer.start() + lprobs, _ = task.inference_step(generator, models, chunk_sample) + if log_prior is not None: + assert lprobs.size(-1) == log_prior.size(0) + lprobs = lprobs - log_prior + if whole_lprobs is None: + whole_lprobs = lprobs.cpu() + else: + whole_lprobs = torch.cat((whole_lprobs, lprobs.cpu()), 1) + num_processed_frames = chunk_sample["ntokens"] + gen_timer.stop(num_processed_frames) + + if i == len(sample) - 1: + num_sentences += len(utt_id) + for j in range(len(utt_id)): + truncated_length = models[0].output_lengths( + task.dataset(args.gen_subset).src_sizes[id[j]] + ) # length is after possible subsampling by the model + mat = whole_lprobs[j, :truncated_length, :] + kaldi_io.write_mat(f, mat.numpy(), key=utt_id[j]) + + logger.info("Dumped {} utterances ({} frames) in {:.1f}s ({:.2f} sentences/s, {:.2f} frames/s)".format( + num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg)) + + return + + +def print_options_meaning_changes(args, logger): + """Options that have different meanings than those in the translation task + are explained here. 
+ """ + logger.info("| --max-tokens is the maximum number of input frames in a batch") + + +def cli_main(): + parser = options.get_generation_parser(default_task="speech_recognition_hybrid") + parser.add_argument("--apply-log-softmax", action="store_true", + help="Apply log-softmax to the neural network outputs for some " + "systems, e.g., Xent. Otherwise use the raw outputs") + parser.add_argument("--state-prior-file", default=None, type=str, metavar="FILE", + help="state prior file. If provided, use this file instead of " + "that from the checkpoint") + args = options.parse_args_and_arch(parser) + main(args) + + +if __name__ == "__main__": + cli_main() diff --git a/espresso/models/speech_lstm.py b/espresso/models/speech_lstm.py index 815070cad..71ba0d371 100644 --- a/espresso/models/speech_lstm.py +++ b/espresso/models/speech_lstm.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional import torch from torch import Tensor @@ -414,9 +414,11 @@ def forward(self, src_tokens, src_lengths: Tensor, **unused): ) def reorder_encoder_out(self, encoder_out: EncoderOut, new_order): + encoder_padding_mask = encoder_out.encoder_padding_mask.index_select(1, new_order) \ + if encoder_out.encoder_padding_mask is not None else None return EncoderOut( encoder_out=encoder_out.encoder_out.index_select(1, new_order), - encoder_padding_mask=encoder_out.encoder_padding_mask.index_select(1, new_order) if encoder_out.encoder_padding_mask is not None else None, + encoder_padding_mask=encoder_padding_mask, encoder_embedding=None, encoder_states=None, src_tokens=None, diff --git a/espresso/models/speech_lstm_encoder_model.py b/espresso/models/speech_lstm_encoder_model.py new file mode 100644 index 000000000..ef1bc10c5 --- /dev/null +++ b/espresso/models/speech_lstm_encoder_model.py @@ -0,0 +1,229 @@ +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
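dump_posteriors.py above turns frame posteriors into pseudo log-likelihoods for Kaldi decoding by subtracting the log state prior; the prior is either read from --state-prior-file or averaged over the state_prior vectors stored in the ensemble's checkpoints. A minimal sketch of that arithmetic with toy tensors (shapes and values are illustrative only):

# Illustrative sketch (not part of the patch): prior subtraction as in dump_posteriors.py.
import torch

priors = [torch.tensor([0.2, 0.8]), torch.tensor([0.4, 0.6])]  # one prior per model
prior = torch.stack(priors, 0).mean(0)                         # average across the ensemble
prior = prior / prior.sum()                                    # re-normalize
log_prior = prior.log()

lprobs = torch.log_softmax(torch.randn(3, 2), dim=-1)          # T x V frame log-posteriors
pseudo_loglik = lprobs - log_prior                             # what is written to the ark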
+ +import logging +from typing import Optional + +import torch +from torch import Tensor +import torch.nn.functional as F + +from fairseq import options +from fairseq.models import ( + FairseqEncoderModel, + register_model, + register_model_architecture, +) +from fairseq.models.fairseq_encoder import EncoderOut +from fairseq.models.lstm import Linear + +from espresso.models.speech_lstm import ConvBNReLU, SpeechLSTMEncoder +import espresso.tools.utils as speech_utils + + +DEFAULT_MAX_SOURCE_POSITIONS = 1e5 + + +logger = logging.getLogger(__name__) + + +@register_model("speech_lstm_encoder_model") +class SpeechLSTMEncoderModel(FairseqEncoderModel): + def __init__(self, encoder, state_prior: Optional[torch.FloatTensor] = None): + super().__init__(encoder) + self.state_prior = state_prior + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + # fmt: off + parser.add_argument("--dropout", type=float, metavar="D", + help="dropout probability") + parser.add_argument("--encoder-conv-channels", type=str, metavar="EXPR", + help="list of encoder convolution's out channels") + parser.add_argument("--encoder-conv-kernel-sizes", type=str, metavar="EXPR", + help="list of encoder convolution's kernel sizes") + parser.add_argument("--encoder-conv-strides", type=str, metavar="EXPR", + help="list of encoder convolution's strides") + parser.add_argument("--encoder-rnn-hidden-size", type=int, metavar="N", + help="encoder rnn's hidden size") + parser.add_argument("--encoder-rnn-layers", type=int, metavar="N", + help="number of rnn encoder layers") + parser.add_argument("--encoder-rnn-bidirectional", + type=lambda x: options.eval_bool(x), + help="make all rnn layers of encoder bidirectional") + parser.add_argument("--encoder-rnn-residual", + type=lambda x: options.eval_bool(x), + help="create residual connections for rnn encoder " + "layers (starting from the 2nd layer), i.e., the actual " + "output of such layer is the sum of its input and output") + + # Granular dropout settings (if not specified these default to --dropout) + parser.add_argument("--encoder-rnn-dropout-in", type=float, metavar="D", + help="dropout probability for encoder rnn's input") + parser.add_argument("--encoder-rnn-dropout-out", type=float, metavar="D", + help="dropout probability for encoder rnn's output") + # fmt: on + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + # make sure that all args are properly defaulted (in case there are any new ones) + base_architecture(args) + max_source_positions = getattr(args, "max_source_positions", DEFAULT_MAX_SOURCE_POSITIONS) + + out_channels = speech_utils.eval_str_nested_list_or_tuple(args.encoder_conv_channels, type=int) + kernel_sizes = speech_utils.eval_str_nested_list_or_tuple(args.encoder_conv_kernel_sizes, type=int) + strides = speech_utils.eval_str_nested_list_or_tuple(args.encoder_conv_strides, type=int) + logger.info("input feature dimension: {}, channels: {}".format(task.feat_dim, task.feat_in_channels)) + assert task.feat_dim % task.feat_in_channels == 0 + conv_layers = ConvBNReLU( + out_channels, kernel_sizes, strides, in_channels=task.feat_in_channels, + ) if out_channels is not None else None + + rnn_encoder_input_size = task.feat_dim // task.feat_in_channels + if conv_layers is not None: + for stride in strides: + if isinstance(stride, (list, tuple)): + assert len(stride) > 0 + s = stride[1] if len(stride) > 1 else stride[0] + else: + assert isinstance(stride, int) + s = stride + rnn_encoder_input_size 
= (rnn_encoder_input_size + s - 1) // s + rnn_encoder_input_size *= out_channels[-1] + else: + rnn_encoder_input_size = task.feat_dim + + encoder = SpeechChunkLSTMEncoder( + conv_layers_before=conv_layers, + input_size=rnn_encoder_input_size, + hidden_size=args.encoder_rnn_hidden_size, + num_layers=args.encoder_rnn_layers, + dropout_in=args.encoder_rnn_dropout_in, + dropout_out=args.encoder_rnn_dropout_out, + bidirectional=args.encoder_rnn_bidirectional, + residual=args.encoder_rnn_residual, + num_targets=getattr(task, "num_targets", None), # targets for encoder-only model + chunk_width=getattr(task, "chunk_width", None), + chunk_left_context=getattr(task, "chunk_left_context", 0), + training_stage=getattr(task, "training_stage", True), + max_source_positions=max_source_positions, + ) + return cls(encoder, state_prior=getattr(task, "initial_state_prior", None)) + + def output_lengths(self, in_lengths): + return self.encoder.output_lengths(in_lengths) + + def get_normalized_probs(self, net_output, log_probs, sample=None): + """Get normalized probabilities (or log probs) from a net's output.""" + encoder_out = net_output.encoder_out + if torch.is_tensor(encoder_out): + logits = encoder_out.float() + if log_probs: + return F.log_softmax(logits, dim=-1) + else: + return F.softmax(logits, dim=-1) + raise NotImplementedError + + def update_state_prior(self, new_state_prior, factor=0.1): + assert self.state_prior is not None + self.state_prior = self.state_prior.to(new_state_prior) + self.state_prior = (1. - factor) * self.state_prior + factor * new_state_prior + self.state_prior = self.state_prior / self.state_prior.sum() # re-normalize + + def state_dict(self): + state_dict = super().state_dict() + state_dict["state_prior"] = self.state_prior + return state_dict + + def load_state_dict(self, state_dict, strict=True, args=None): + state_dict_subset = state_dict.copy() + self.state_prior = state_dict.get("state_prior", None) + if "state_prior" in state_dict: + self.state_prior = state_dict["state_prior"] + del state_dict_subset["state_prior"] + super().load_state_dict(state_dict_subset, strict=strict, args=args) + + +class SpeechChunkLSTMEncoder(SpeechLSTMEncoder): + """LSTM encoder.""" + def __init__( + self, conv_layers_before=None, input_size=83, hidden_size=512, + num_layers=1, dropout_in=0.1, dropout_out=0.1, bidirectional=False, + residual=False, left_pad=False, padding_value=0., + num_targets=None, chunk_width=20, chunk_left_context=0, training_stage=True, + max_source_positions=DEFAULT_MAX_SOURCE_POSITIONS, + ): + super().__init__( + conv_layers_before=conv_layers_before, input_size=input_size, hidden_size=hidden_size, + num_layers=num_layers, dropout_in=dropout_in, dropout_out=dropout_in, + bidirectional=bidirectional, residual=residual, left_pad=left_pad, + padding_value=padding_value, max_source_positions=max_source_positions, + ) + receptive_field_radius = sum(conv.padding[0] for conv in conv_layers_before.convolutions) \ + if conv_layers_before is not None else 0 + assert chunk_width is None or chunk_width > 0 + assert (conv_layers_before is None and chunk_left_context >= 0) or \ + (conv_layers_before is not None and chunk_left_context >= receptive_field_radius) + self.out_chunk_begin = self.output_lengths(chunk_left_context + 1) - 1 + self.out_chunk_end = self.output_lengths(chunk_left_context + chunk_width) \ + if chunk_width is not None else None + self.training_stage = training_stage + + # only for encoder-only model + self.fc_out = Linear(self.output_units, num_targets, 
dropout=dropout_out) \ + if num_targets is not None else None + + def forward(self, src_tokens, src_lengths: Tensor, **unused): + out = super().forward(src_tokens, src_lengths, **unused) + x, encoder_padding_mask, x_lengths = out.encoder_out, out.encoder_padding_mask, out.src_lengths + + # determine which output frame to select for loss evaluation/test, assuming + # all examples in a batch are of the same length for chunk-wise training/test + if ( + self.out_chunk_end is not None + and (self.training or not self.training_stage) + ): + x = x[self.out_chunk_begin: self.out_chunk_end] # T x B x C -> W x B x C + x_lengths = x_lengths.fill_(x.size(0)) + assert encoder_padding_mask is None + + if self.fc_out is not None: + x = self.fc_out(x) # T x B x C -> T x B x V + + return EncoderOut( + encoder_out=x, # T x B x C + encoder_padding_mask=encoder_padding_mask if encoder_padding_mask.any() else None, # T x B + encoder_embedding=None, + encoder_states=None, + src_tokens=None, + src_lengths=x_lengths, # B + ) + + +@register_model_architecture("speech_lstm_encoder_model", "speech_lstm_encoder_model") +def base_architecture(args): + args.dropout = getattr(args, "dropout", 0.4) + args.encoder_conv_channels = getattr( + args, "encoder_conv_channels", "[64, 64, 128, 128]", + ) + args.encoder_conv_kernel_sizes = getattr( + args, "encoder_conv_kernel_sizes", "[(3, 3), (3, 3), (3, 3), (3, 3)]", + ) + args.encoder_conv_strides = getattr( + args, "encoder_conv_strides", "[(1, 1), (2, 2), (1, 1), (2, 2)]", + ) + args.encoder_rnn_hidden_size = getattr(args, "encoder_rnn_hidden_size", 320) + args.encoder_rnn_layers = getattr(args, "encoder_rnn_layers", 3) + args.encoder_rnn_bidirectional = getattr(args, "encoder_rnn_bidirectional", True) + args.encoder_rnn_residual = getattr(args, "encoder_rnn_residual", False) + args.encoder_rnn_dropout_in = getattr(args, "encoder_rnn_dropout_in", args.dropout) + args.encoder_rnn_dropout_out = getattr(args, "encoder_rnn_dropout_out", args.dropout) + + +@register_model_architecture("speech_lstm_encoder_model", "speech_conv_lstm_encoder_model_wsj") +def encoder_conv_lstm_wsj(args): + base_architecture(args) diff --git a/espresso/models/speech_tdnn.py b/espresso/models/speech_tdnn.py new file mode 100644 index 000000000..0412dd8e0 --- /dev/null +++ b/espresso/models/speech_tdnn.py @@ -0,0 +1,302 @@ +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
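In SpeechLSTMEncoderModel.build_model() above, the flattened RNN input size is obtained by ceil-dividing the per-channel feature dimension by each convolution's frequency-axis stride and then multiplying by the last conv layer's channel count. With the default conv settings registered above and an assumed 80-dim, single-channel input:

# Illustrative sketch (not part of the patch): RNN input size implied by the conv front-end.
# feat_dim = 80 and in_channels = 1 are assumptions; the conv settings are the defaults above.
feat_dim, in_channels = 80, 1
out_channels = [64, 64, 128, 128]
strides = [(1, 1), (2, 2), (1, 1), (2, 2)]

size = feat_dim // in_channels
for stride in strides:
    s = stride[1] if isinstance(stride, (list, tuple)) else stride  # frequency-axis stride
    size = (size + s - 1) // s                                      # ceil division
size *= out_channels[-1]
print(size)  # 80 -> 80 -> 40 -> 40 -> 20 along frequency; 20 * 128 = 2560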
+ +import logging +from typing import Optional + +import torch +from torch import Tensor +import torch.nn as nn +import torch.nn.functional as F + +from fairseq import options +from fairseq.models import ( + FairseqEncoder, + FairseqEncoderModel, + register_model, + register_model_architecture, +) +from fairseq.models.fairseq_encoder import EncoderOut +from fairseq.models.lstm import Linear + +import espresso.tools.utils as speech_utils + + +logger = logging.getLogger(__name__) + + +@register_model("speech_tdnn") +class SpeechTdnnEncoderModel(FairseqEncoderModel): + def __init__(self, encoder, state_prior: Optional[torch.FloatTensor] = None): + super().__init__(encoder) + self.num_updates = 0 + self.state_prior = state_prior + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + # fmt: off + parser.add_argument("--dropout", type=float, metavar="D", + help="dropout probability") + parser.add_argument("--hidden-sizes", type=str, metavar="EXPR", + help="list of hidden sizes for all Tdnn layers") + parser.add_argument("--kernel-sizes", type=str, metavar="EXPR", + help="list of all Tdnn layer\'s kernel sizes") + parser.add_argument("--strides", type=str, metavar="EXPR", + help="list of all Tdnn layer\'s strides") + parser.add_argument("--dilations", type=str, metavar="EXPR", + help="list of all Tdnn layer\'s dilations") + parser.add_argument("--num-layers", type=int, metavar="N", + help="number of Tdnn layers") + parser.add_argument("--residual", type=lambda x: options.eval_bool(x), + help="create residual connections for rnn encoder " + "layers (starting from the 2nd layer), i.e., the actual " + "output of such layer is the sum of its input and output") + + # Granular dropout settings (if not specified these default to --dropout) + parser.add_argument("--dropout-in", type=float, metavar="D", + help="dropout probability for encoder\'s input") + parser.add_argument("--dropout-out", type=float, metavar="D", + help="dropout probability for Tdnn layers\' output") + # fmt: on + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + # make sure that all args are properly defaulted (in case there are any new ones) + base_architecture(args) + + hidden_sizes = speech_utils.eval_str_nested_list_or_tuple(args.hidden_sizes, type=int) + kernel_sizes = speech_utils.eval_str_nested_list_or_tuple(args.kernel_sizes, type=int) + strides = speech_utils.eval_str_nested_list_or_tuple(args.strides, type=int) + dilations = speech_utils.eval_str_nested_list_or_tuple(args.dilations, type=int) + logger.info("input feature dimension: {}, output dimension: {}".format(task.feat_dim, task.num_targets)) + + encoder = SpeechTdnnEncoder( + input_size=task.feat_dim, + output_size=task.num_targets, + hidden_sizes=hidden_sizes, + kernel_sizes=kernel_sizes, + strides=strides, + dilations=dilations, + num_layers=args.num_layers, + dropout_in=args.dropout_in, + dropout_out=args.dropout_out, + residual=args.residual, + chunk_width=getattr(task, "chunk_width", None), + chunk_left_context=getattr(task, "chunk_left_context", 0), + training_stage=getattr(task, "training_stage", True), + ) + return cls(encoder, state_prior=getattr(task, "initial_state_prior", None)) + + def set_num_updates(self, num_updates): + self.num_updates = num_updates + super().set_num_updates(num_updates) + + def output_lengths(self, in_lengths): + return self.encoder.output_lengths(in_lengths) + + def get_normalized_probs(self, net_output, log_probs, sample=None): + """Get normalized 
probabilities (or log probs) from a net's output.""" + encoder_out = net_output.encoder_out + if torch.is_tensor(encoder_out): + logits = encoder_out.float() + if log_probs: + return F.log_softmax(logits, dim=-1) + else: + return F.softmax(logits, dim=-1) + raise NotImplementedError + + def get_logits(self, net_output): + logits = net_output.encoder_out.transpose(0, 1).squeeze(2) # T x B x 1 -> B x T + return logits + + def update_state_prior(self, new_state_prior, factor=0.1): + assert self.state_prior is not None + self.state_prior = self.state_prior.to(new_state_prior) + self.state_prior = (1. - factor) * self.state_prior + factor * new_state_prior + self.state_prior = self.state_prior / self.state_prior.sum() # re-normalize + + def state_dict(self): + state_dict = super().state_dict() + state_dict["state_prior"] = self.state_prior + return state_dict + + def load_state_dict(self, state_dict, strict=True, args=None): + state_dict_subset = state_dict.copy() + self.state_prior = state_dict.get("state_prior", None) + if "state_prior" in state_dict: + self.state_prior = state_dict["state_prior"] + del state_dict_subset["state_prior"] + super().load_state_dict(state_dict_subset, strict=strict, args=args) + + +class TdnnBNReLU(nn.Module): + """A block of Tdnn-BatchNorm-ReLU layers.""" + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1): + super().__init__() + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + self.padding = dilation * (kernel_size - 1) // 2 + self.tdnn = nn.Conv1d( + in_channels, out_channels, kernel_size, + stride=stride, padding=self.padding, dilation=dilation, + ) + self.bn = nn.BatchNorm1d(out_channels) + + def output_lengths(self, in_lengths): + out_lengths = ( + in_lengths + 2 * self.padding - self.dilation * (self.kernel_size - 1) + + self.stride - 1 + ) // self.stride + return out_lengths + + def forward(self, src, src_lengths): + x = src.transpose(1, 2).contiguous() # B x T x C -> B x C x T + x = F.relu(self.bn(self.tdnn(x))) + x = x.transpose(2, 1).contiguous() # B x C x T -> B x T x C + x_lengths = self.output_lengths(src_lengths) + padding_mask = ~speech_utils.sequence_mask(x_lengths, x.size(1)) + if padding_mask.any(): + x = x.masked_fill(padding_mask.unsqueeze(-1), 0.0) + + return x, x_lengths, padding_mask + + +class SpeechTdnnEncoder(FairseqEncoder): + """Tdnn encoder.""" + def __init__( + self, input_size, output_size, hidden_sizes=256, kernel_sizes=3, strides=1, + dilations=3, num_layers=1, dropout_in=0.0, dropout_out=0.0, residual=False, + chunk_width=None, chunk_left_context=0, training_stage=True, + ): + super().__init__(None) # no src dictionary + self.num_layers = num_layers + if isinstance(hidden_sizes, int): + hidden_sizes = [hidden_sizes] * num_layers + else: + assert len(hidden_sizes) == num_layers + if isinstance(kernel_sizes, int): + kernel_sizes = [kernel_sizes] * num_layers + else: + assert len(kernel_sizes) == num_layers + if isinstance(strides, int): + strides = [strides] * num_layers + else: + assert len(strides) == num_layers + if isinstance(dilations, int): + dilations = [dilations] * num_layers + else: + assert len(dilations) == num_layers + self.dropout_in = dropout_in + self.dropout_out = dropout_out + self.residual = residual + + self.tdnn = nn.ModuleList([ + TdnnBNReLU( + in_channels=input_size if layer == 0 else hidden_sizes[layer - 1], + out_channels=hidden_sizes[layer], kernel_size=kernel_sizes[layer], + stride=strides[layer], dilation=dilations[layer], + ) + for layer 
in range(num_layers) + ]) + + receptive_field_radius = sum(l.padding for l in self.tdnn) + assert chunk_width is None or (chunk_width > 0 and chunk_left_context >= receptive_field_radius) + if ( + chunk_width is not None and chunk_width > 0 + and chunk_left_context > receptive_field_radius + ): + logger.warning("chunk_{{left,right}}_context can be reduced to {}".format(receptive_field_radius)) + self.out_chunk_begin = self.output_lengths(chunk_left_context + 1) - 1 + self.out_chunk_end = self.output_lengths(chunk_left_context + chunk_width) \ + if chunk_width is not None else None + self.training_stage = training_stage + + self.fc_out = Linear(hidden_sizes[-1], output_size, dropout=dropout_out) + + def output_lengths(self, in_lengths): + out_lengths = in_lengths + for layer in self.tdnn: + out_lengths = layer.output_lengths(out_lengths) + return out_lengths + + def forward(self, src_tokens, src_lengths: Tensor, **unused): + x, encoder_padding_mask, x_lengths = self.extract_features(src_tokens, src_lengths) + if ( + self.out_chunk_end is not None + and (self.training or not self.training_stage) + ): + # determine which output frame to select for loss evaluation/test, assuming + # all examples in a batch are of the same length for chunk-wise training/test + x = x[self.out_chunk_begin: self.out_chunk_end] # T x B x C -> W x B x C + x_lengths = x_lengths.fill_(x.size(0)) + assert not encoder_padding_mask.any() + x = self.output_layer(x) + + return EncoderOut( + encoder_out=x, # T x B x C + encoder_padding_mask=encoder_padding_mask if encoder_padding_mask.any() else None, # T x B + encoder_embedding=None, + encoder_states=None, + src_tokens=None, + src_lengths=x_lengths, # B + ) + + def extract_features(self, src_tokens, src_lengths, **unused): + x, x_lengths = src_tokens, src_lengths + x = F.dropout(x, p=self.dropout_in, training=self.training) + + for i in range(len(self.tdnn)): + if self.residual and i > 0: # residual connection starts from the 2nd layer + prev_x = x + # apply Tdnn + x, x_lengths, padding_mask = self.tdnn[i](x, x_lengths) + x = F.dropout(x, p=self.dropout_out, training=self.training) + x = x + prev_x if self.residual and i > 0 and x.size(1) == prev_x.size(1) else x + + x = x.transpose(0, 1) # B x T x C -> T x B x C + encoder_padding_mask = padding_mask.t() + + return x, encoder_padding_mask, x_lengths + + def output_layer(self, features, **kwargs): + """Project features to the vocabulary size.""" + return self.fc_out(features) # T x B x C -> T x B x V + + def reorder_encoder_out(self, encoder_out: EncoderOut, new_order): + encoder_padding_mask = encoder_out.encoder_padding_mask.index_select(1, new_order) \ + if encoder_out.encoder_padding_mask is not None else None + return EncoderOut( + encoder_out=encoder_out.encoder_out.index_select(1, new_order), + encoder_padding_mask=encoder_padding_mask, + encoder_embedding=None, + encoder_states=None, + src_tokens=None, + src_lengths=encoder_out.src_lengths.index_select(0, new_order), + ) + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return int(1e5) # an arbitrary large number + + +@register_model_architecture("speech_tdnn", "speech_tdnn") +def base_architecture(args): + args.dropout = getattr(args, "dropout", 0.0) + args.hidden_sizes = getattr(args, "hidden_sizes", "640") + args.kernel_sizes = getattr(args, "kernel_sizes", "[5, 3, 3, 3, 3]") + args.strides = getattr(args, "strides", "1") + args.dilations = getattr(args, "dilations", "[1, 1, 1, 3, 3]") + args.num_layers = getattr(args, 
"num_layers", 5) + args.residual = getattr(args, "residual", False) + args.dropout_in = getattr(args, "dropout_in", args.dropout) + args.dropout_out = getattr(args, "dropout_out", args.dropout) + + +@register_model_architecture("speech_tdnn", "speech_tdnn_wsj") +def tdnn_wsj(args): + base_architecture(args) diff --git a/espresso/speech_train.py b/espresso/speech_train.py index b381e6420..d05eb60ae 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -202,6 +202,10 @@ def train(args, trainer, task, epoch_itr, max_update=math.inf): # the end-of-epoch stats will still be preserved metrics.reset_meters('train_inner') + # update the state prior stored in the model for cross-entropy training + if hasattr(task, 'update_state_prior'): + task.update_state_prior(trainer.get_model()) + valid_losses = validate_and_save(args, trainer, task, epoch_itr, valid_subsets) if should_stop_early(args, valid_losses[0]) or num_updates >= max_update: break diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index f60e94de4..f5bd5a845 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -328,7 +328,7 @@ def build_generator(self, models, args): match_source_len=getattr(args, "match_source_len", False), no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0), search_strategy=search_strategy, - lm_weight = getattr(args, "lm_weight", 0.0), + lm_weight=getattr(args, "lm_weight", 0.0), eos_factor=getattr(args, "eos_factor", None), ) diff --git a/espresso/tasks/speech_recognition_hybrid.py b/espresso/tasks/speech_recognition_hybrid.py new file mode 100644 index 000000000..dadd9404b --- /dev/null +++ b/espresso/tasks/speech_recognition_hybrid.py @@ -0,0 +1,382 @@ +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from collections import OrderedDict +import itertools +import json +import logging +import os + +import torch + +from fairseq import utils +from fairseq.data import ConcatDataset + +from fairseq.tasks import FairseqTask, register_task + +from espresso.data import ( + AliScpCachedDataset, + AsrChainDataset, + AsrXentDataset, + AsrDictionary, + AsrTextDataset, + FeatScpCachedDataset, + NumeratorGraphDataset, +) + +try: + import kaldi_io +except ImportError: + raise ImportError("Please install kaldi_io with: pip install kaldi_io") + + +logger = logging.getLogger(__name__) + + +def get_asr_dataset_from_json( + data_path, split, dictionary, + combine, upsample_primary, + max_source_positions, max_target_positions, + lf_mmi=True, + seed=1, specaugment_config=None, + chunk_width=None, chunk_left_context=None, chunk_right_context=None, label_delay=0, +): + """ + Parse data json and create dataset. + See espresso/tools/asr_prep_json.py which pack json from raw files + Json example: + { + "011c0202": { + "feat": "data/train_si284_spe2e_hires/data/raw_mfcc_train_si284_spe2e_hires.1.ark:24847", + "numerator_fst": "exp/chain/e2e_bichar_tree_tied1a/fst.1.ark:6704", + "alignment": "exp/tri3/ali.ark:8769", + "text": "THE HOTELi OPERATOR'S EMBASSY", + "utt2num_frames": "693", + }, + "011c0203": { + ... 
+ } + } + """ + src_datasets = [] + tgt_datasets = [] + text_datasets = [] + + for k in itertools.count(): + split_k = split + (str(k) if k > 0 else "") + data_json_path = os.path.join(data_path, "{}.json".format(split_k)) + if not os.path.isfile(data_json_path): + if k > 0: + break + else: + raise FileNotFoundError("Dataset not found: {}".format(data_json_path)) + + with open(data_json_path, "rb") as f: + loaded_json = json.load(f, object_pairs_hook=OrderedDict) + + utt_ids, feats, numerator_fsts, alignments, text, utt2num_frames = [], [], [], [], [], [] + for utt_id, val in loaded_json.items(): + utt_ids.append(utt_id) + feats.append(val["feat"]) + if "numerator_fst" in val: + numerator_fsts.append(val["numerator_fst"]) + if "alignment" in val: + alignments.append(val["alignment"]) + if "text" in val: + text.append(val["text"]) + if "utt2num_frames" in val: + utt2num_frames.append(int(val["utt2num_frames"])) + + assert len(utt2num_frames) == 0 or len(utt_ids) == len(utt2num_frames) + src_datasets.append(FeatScpCachedDataset( + utt_ids, feats, utt2num_frames=utt2num_frames, seed=seed, + specaugment_config=specaugment_config if split == "train" else None, + ordered_prefetch=True, + )) + if lf_mmi: + if len(numerator_fsts) > 0: + assert len(utt_ids) == len(numerator_fsts) + tgt_datasets.append(NumeratorGraphDataset(utt_ids, numerator_fsts)) + else: # cross-entropy + if len(alignments) > 0: + assert len(utt_ids) == len(alignments) + tgt_datasets.append(AliScpCachedDataset( + utt_ids, alignments, utt2num_frames=utt2num_frames, ordered_prefetch=True + )) + + if len(text) > 0: + assert len(utt_ids) == len(text) + text_datasets.append(AsrTextDataset(utt_ids, text, dictionary, append_eos=False)) + + logger.info("{} {} examples".format(data_json_path, len(src_datasets[-1]))) + + if not combine: + break + + assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0 + assert len(src_datasets) == len(text_datasets) or len(text_datasets) == 0 + + feat_dim = src_datasets[0].feat_dim + + if len(src_datasets) == 1: + src_dataset = src_datasets[0] + tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None + text_dataset = text_datasets[0] if len(text_datasets) > 0 else None + else: + for i in range(1, len(src_datasets)): + assert feat_dim == src_datasets[i].feat_dim, \ + "feature dimension does not match across multiple json files" + sample_ratios = [1] * len(src_datasets) + sample_ratios[0] = upsample_primary + src_dataset = ConcatDataset(src_datasets, sample_ratios) + if len(tgt_datasets) > 0: + tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) + else: + tgt_dataset = None + if len(text_datasets) > 0: + text_dataset = ConcatDataset(text_datasets, sample_ratios) + else: + text_dataset = None + + tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None + if lf_mmi: + return AsrChainDataset( + src_dataset, src_dataset.sizes, + tgt_dataset, tgt_dataset_sizes, + text=text_dataset, + max_source_positions=max_source_positions, + max_target_positions=max_target_positions, + ) + else: + return AsrXentDataset( + src_dataset, src_dataset.sizes, + tgt_dataset, tgt_dataset_sizes, + text=text_dataset, + max_source_positions=max_source_positions, + max_target_positions=max_target_positions, + seed=seed, chunk_width=chunk_width, + chunk_left_context=chunk_left_context, chunk_right_context=chunk_right_context, + label_delay=label_delay, random_chunking=(split == "train" and chunk_width is not None), + ) + + +@register_task("speech_recognition_hybrid") +class 
SpeechRecognitionHybridTask(FairseqTask): + """ + Hybrid speech recognition with lattice-free MMI or cross-entropy loss. + Currently it dumps posteriors from neural networks' output on-the-fly or + as an ark file for Kaldi to decode. + + Args: + dictionary (~fairseq.data.AsrDictionary): dictionary for the final text + + .. note:: + + The speech recognition with lattice-free MMI task is compatible with + :mod:`speech-train`, and :mod:`dump-posteriors`. The results are not + strictly reproducible (i.e., there is some randomness among different + runs with the same exprimental setting) due to the use of `atomicAdd` + while accumulating gradients w.r.t. pdf-ids in backprop of LF-MMI loss. + See https://pytorch.org/docs/stable/notes/randomness.html for details. + + The speech recognition task provides the following additional command-line + arguments: + + .. argparse:: + :ref: fairseq.tasks.speech_recognition_parser + :prog: + """ + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + # fmt: off + parser.add_argument("data", help="path to data directory") + parser.add_argument("--dict", default=None, type=str, + help="path to the dictionary") + parser.add_argument("--non-lang-syms", default=None, type=str, + help="path to a file listing non-linguistic symbols, e.g., " + "etc. One entry per line. To be filtered out when calculating WER/CER.") + parser.add_argument("--wer-output-filter", default=None, type=str, + help="path to wer_output_filter file for WER evaluation") + parser.add_argument("--max-source-positions", default=1024, type=int, metavar="N", + help="max number of frames in the source sequence") + parser.add_argument("--max-target-positions", default=1024, type=int, metavar="N", + help="max number of tokens in the target sequence") + parser.add_argument("--upsample-primary", default=1, type=int, + help="amount to upsample primary dataset") + parser.add_argument("--feat-in-channels", default=1, type=int, metavar="N", + help="feature input channels") + parser.add_argument("--specaugment-config", default=None, type=str, metavar="EXPR", + help="SpecAugment config string. If not None and not empty, " + "then apply SpecAugment. Should be an evaluatable expression of " + "a python dict. See speech_tools.specaug_interpolate.specaug() for " + "all allowed arguments. Argments not appearing in this string " + "will take on their default values") + + parser.add_argument("--num-targets", type=int, metavar="N", + help="number of targets for training (e.g., num pdf-ids)") + parser.add_argument("--initial-state-prior-file", default=None, type=str, metavar="FILE", + help="path to the file containing initial state prior. Only relevant " + "with cross-entropy training") + parser.add_argument("--state-prior-update-interval", default=None, type=int, metavar="N", + help="state prior estimate will be updated every this " + "number of updates during training. If None, then use " + "the initial value estimated from the alignments. Only relevant with " + "cross-entropy training") + parser.add_argument("--state-prior-update-smoothing", default=0.1, type=float, metavar="D", + help="smoothing factor while updating state prior estimate. Only " + "relevant with cross-entropy training") + parser.add_argument("--chunk-width", default=None, type=int, metavar="D", + help="chunk width for train/test data. Only relevant with chunk-wise " + "training (including both cross-entropy and Lattice-free MMI). 
" + "Do utterance-wise training/test if not specified") + parser.add_argument("--chunk-left-context", default=0, type=int, metavar="D", + help="number of frames appended to the left of a chunk") + parser.add_argument("--chunk-right-context", default=0, type=int, metavar="D", + help="number of frames appended to the right of a chunk") + parser.add_argument("--label-delay", default=0, type=int, metavar="D", + help="offet of alignments as prediction labels. Maybe useful " + "in archs such as asymmetric convolution, unidirectional LSTM, etc. " + "It can be negative. Only relevant with chunk-wise cross-entropy training") + # fmt: off + + @classmethod + def load_dictionary(cls, filename, non_lang_syms=None): + """Load the dictionary from the filename + Args: + filename (str): the filename + non_lang_syms (str): non_lang_syms filename + """ + return AsrDictionary.load(filename, f_non_lang_syms=non_lang_syms) + + @classmethod + def build_dictionary(cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8): + """Disable this method + """ + raise NotImplementedError + + def __init__(self, args, dictionary): + super().__init__(args) + self.dictionary = dictionary + self.feat_in_channels = args.feat_in_channels + self.specaugment_config = args.specaugment_config + self.num_targets = args.num_targets + self.training_stage = hasattr(args, "valid_subset") + + # the following attributes are related to state_prior estimate + self.initial_state_prior = None + if args.initial_state_prior_file is not None: # only relevant for Xent training, used in models + self.initial_state_prior = kaldi_io.read_vec_flt(args.initial_state_prior_file) + self.initial_state_prior = torch.from_numpy(self.initial_state_prior) + assert self.initial_state_prior.size(0) == self.num_targets, \ + "length of initial_state_prior ({}) != num_targets ({})".format( + self.initial_state_prior.size(0), self.num_targets + ) + self.state_prior_update_interval = args.state_prior_update_interval + if self.state_prior_update_interval is None and self.initial_state_prior is not None: + logger.info("state prior will not be updated during training") + self.state_prior_update_smoothing = args.state_prior_update_smoothing + self.averaged_state_post = None # state poterior will be saved here before commited as new state prior + + # the following 4 options are for chunk-wise training/test (including Xent and LF-MMI) + self.chunk_width = args.chunk_width + self.chunk_left_context = args.chunk_left_context + self.chunk_right_context = args.chunk_right_context + self.label_delay = args.label_delay # only for chunk-wise Xent training + + torch.backends.cudnn.deterministic = True + + @classmethod + def setup_task(cls, args, **kwargs): + """Setup the task (e.g., load dictionaries). + + Args: + args (argparse.Namespace): parsed command-line arguments + """ + # load dictionaries + dict_path = args.dict + dictionary = cls.load_dictionary(dict_path, non_lang_syms=args.non_lang_syms) if \ + dict_path is not None else None + if dictionary is not None: + logger.info("dictionary: {} types".format(len(dictionary))) + return cls(args, dictionary) + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. 
+ + Args: + split (str): name of the split (e.g., train, valid, test) + """ + paths = utils.split_paths(self.args.data) + assert len(paths) > 0 + data_path = paths[(epoch - 1) % len(paths)] + + self.datasets[split] = get_asr_dataset_from_json( + data_path, split, self.dictionary, + combine=combine, + upsample_primary=self.args.upsample_primary, + max_source_positions=self.args.max_source_positions, + max_target_positions=self.args.max_target_positions, + lf_mmi=(self.args.criterion == "lattice_free_mmi"), + seed=self.args.seed, specaugment_config=self.specaugment_config, + chunk_width=None if self.training_stage and split in self.args.valid_subset.split(",") else self.chunk_width, + chunk_left_context=self.chunk_left_context, chunk_right_context=self.chunk_right_context, + label_delay=self.label_delay, + ) + + src_dataset = self.datasets[split].src + self.feat_dim = src_dataset.feat_dim if not isinstance(src_dataset, ConcatDataset) \ + else src_dataset.datasets[0].feat_dim + + def build_generator(self, models, args): + if args.score_reference: + args.score_reference = False + logger.warning( + "--score-reference is not applicable to speech recognition, ignoring it." + ) + from espresso.tools.generate_log_probs_for_decoding import GenerateLogProbsForDecoding + apply_log_softmax = getattr(args, "apply_log_softmax", False) + return GenerateLogProbsForDecoding(models, apply_log_softmax=apply_log_softmax) + + def build_dataset_for_inference(self, src_tokens, src_lengths): + return AsrChainDataset(src_tokens, src_lengths) + + def inference_step(self, generator, models, sample): + with torch.no_grad(): + return generator.generate(models, sample) + + def reduce_metrics(self, logging_outputs, criterion): + super().reduce_metrics(logging_outputs, criterion) + state_post = [] + for log in logging_outputs: + post = log.get("state_post", None) + if post is not None: + state_post.append(post) + if len(state_post) > 0: + # collect state priors from all workers and do weighted average + weights = state_post[0].new([log.get("ntokens", 0) for log in logging_outputs]) + weights = weights / weights.sum() # N + with torch.no_grad(): + stacked_state_post = torch.stack(state_post, dim=1) # V x N + self.averaged_state_post = stacked_state_post.mv(weights) # V + else: + self.averaged_state_post = None + + def update_state_prior(self, model): + if self.averaged_state_post is not None: + assert hasattr(model, "update_state_prior") + model.update_state_prior(self.averaged_state_post, self.state_prior_update_smoothing) + + def max_positions(self): + """Return the max sentence length allowed by the task.""" + return (self.args.max_source_positions, self.args.max_target_positions) + + @property + def target_dictionary(self): + """Return the target :class:`~fairseq.data.AsrDictionary`.""" + # Note: padding idx for criterions would be self.target_dictionary.pad() if it + # returns not None. + return None diff --git a/espresso/tools/.gitignore b/espresso/tools/.gitignore index 6fc07e5f7..67d77be3d 100644 --- a/espresso/tools/.gitignore +++ b/espresso/tools/.gitignore @@ -1 +1,3 @@ kaldi +openfst* +pychain diff --git a/espresso/tools/Makefile b/espresso/tools/Makefile index 5cba5e3e9..81ca5fb0f 100644 --- a/espresso/tools/Makefile +++ b/espresso/tools/Makefile @@ -1,4 +1,19 @@ KALDI = +PYTHON_DIR = ~/anaconda3/bin + +CXX ?= g++ + +WGET ?= wget + +# Note: OpenFst requires a relatively recent C++ compiler with C++11 support, +# e.g. g++ >= 4.7, Apple clang >= 5.0 or LLVM clang >= 3.3. 
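+# OpenFst release to download and build locally (it is symlinked to ./openfst
+# below); override OPENFST_VERSION on the make command line if a different
+# release is needed.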
+OPENFST_VERSION ?= 1.7.5 + +# Default features configured for OpenFST; can be overridden in the make command line. +OPENFST_COMFIGURE ?= --enable-static --enable-shared --enable-ngram-fsts + +CPPFLAGS ?= -D_GLIBCXX_USE_CXX11_ABI=0 +CXXFLAGS ?= -D_GLIBCXX_USE_CXX11_ABI=0 .PHONY: all clean @@ -14,5 +29,53 @@ kaldi: cd kaldi/src; ./configure --shared --use-cuda=no; $(MAKE) depend; $(MAKE) all endif -clean: +clean: openfst_cleaned + rm -rf pychain rm -rf kaldi + +openfst_cleaned: + $(MAKE) -C openfst-$(OPENFST_VERSION) clean + +.PHONY: openfst # so target will be made even though "openfst" exists. +openfst: openfst_compiled openfst-$(OPENFST_VERSION)/lib + -rm -f openfst + -ln -s openfst-$(OPENFST_VERSION) openfst + +.PHONY: openfst_compiled +openfst_compiled: openfst-$(OPENFST_VERSION)/Makefile + $(MAKE) -C openfst-$(OPENFST_VERSION) install MAKEOVERRIDES= + +openfst-$(OPENFST_VERSION)/lib: | openfst-$(OPENFST_VERSION)/Makefile + -cd openfst-$(OPENFST_VERSION) && [ -d lib64 ] && [ ! -d lib ] && ln -s lib64 lib + +# Add the -O flag to CXXFLAGS on cygwin as it can fix the compilation error +# "file too big". +ifeq ($(OSTYPE),cygwin) + # Note: OSTYPE path is probably dead for latest cygwin64 (installed on 2016/11/11). + openfst_add_CXXFLAGS = -O -Wa,-mbig-obj +else ifeq ($(OS),Windows_NT) + # This new OS path is confirmed working on Windows 10 / Cygwin64. + openfst_add_CXXFLAGS = -O -Wa,-mbig-obj +else + openfst_add_CXXFLAGS = +endif + +openfst-$(OPENFST_VERSION)/Makefile: openfst-$(OPENFST_VERSION) + cd openfst-$(OPENFST_VERSION)/ && \ + ./configure --prefix=`pwd` $(OPENFST_CONFIGURE) CXX="$(CXX)" CPPFLAGS="$(CPPFLAGS)" CXXFLAGS="$(CXXFLAGS) $(openfst_add_CXXFLAGS)" LDFLAGS="$(LDFLAGS)" LIBS="-ldl" + +openfst-$(OPENFST_VERSION): openfst-$(OPENFST_VERSION).tar.gz + tar xozf openfst-$(OPENFST_VERSION).tar.gz + +openfst-$(OPENFST_VERSION).tar.gz: + $(WGET) -T 10 -t 1 http://www.openfst.org/twiki/pub/FST/FstDownload/openfst-$(OPENFST_VERSION).tar.gz || \ + $(WGET) -T 10 -t 3 https://www.openslr.org/resources/2/openfst-$(OPENFST_VERSION).tar.gz; + +.PHONY: pychain +pychain: + test -d pychain || git clone https://github.com/YiwenShaoStephen/pychain.git + export OPENFST_PATH=`pwd`/openfst && \ + export LD_LIBRARY_PATH=`pwd`/openfst/lib:$$LD_LIBRARY_PATH && \ + export PATH=$(PYTHON_DIR):$$PATH && \ + cd pychain/openfst_binding && python3 setup.py install && \ + cd ../pytorch_binding && python3 setup.py install diff --git a/espresso/tools/asr_prep_json.py b/espresso/tools/asr_prep_json.py index f39755bfa..3f4960abc 100755 --- a/espresso/tools/asr_prep_json.py +++ b/espresso/tools/asr_prep_json.py @@ -43,6 +43,12 @@ def main(): help="path(s) to scp feature file(s)") parser.add_argument("--token-text-files", nargs="+", default=None, help="path(s) to token_text file(s)") + parser.add_argument("--text-files", nargs="+", default=None, + help="path(s) to text file(s)") + parser.add_argument("--numerator-fst-files", nargs="+", default=None, + help="path(s) to numerator fst file(s)") + parser.add_argument("--alignment-files", nargs="+", default=None, + help="path(s) to alignment file(s)") parser.add_argument("--utt2num-frames-files", nargs="+", default=None, help="path(s) to utt2num_frames file(s)") parser.add_argument("--output", required=True, type=argparse.FileType("w"), @@ -54,6 +60,12 @@ def main(): obj = read_file(obj, "feat", str, *(args.feat_files)) if args.token_text_files is not None: obj = read_file(obj, "token_text", str, *(args.token_text_files)) + if args.text_files is not None: + obj = 
read_file(obj, "text", str, *(args.text_files)) + if args.numerator_fst_files is not None: + obj = read_file(obj, "numerator_fst", str, *(args.numerator_fst_files)) + if args.alignment_files is not None: + obj = read_file(obj, "alignment", str, *(args.alignment_files)) if args.utt2num_frames_files is not None: obj = read_file(obj, "utt2num_frames", int, *(args.utt2num_frames_files)) diff --git a/espresso/tools/estimate_initial_state_prior_from_alignments.py b/espresso/tools/estimate_initial_state_prior_from_alignments.py new file mode 100755 index 000000000..a1d111106 --- /dev/null +++ b/espresso/tools/estimate_initial_state_prior_from_alignments.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import logging +import sys + +import numpy as np + +try: + import kaldi_io +except ImportError: + raise ImportError('Please install kaldi_io with: pip install kaldi_io') + + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO, + stream=sys.stdout, +) +logger = logging.getLogger("espresso.tools.estimate_initial_state_prior_from_alignments") + + +def get_parser(): + parser = argparse.ArgumentParser( + description="Obtain initial state prior from alignments") + # fmt: off + parser.add_argument("--alignment-files", nargs="+", required=True, + help="path(s) to alignment file(s)") + parser.add_argument("--prior-dim", required=True, type=int, + help="state prior dimension, i.e., the number of states") + parser.add_argument("--prior-floor", type=float, default=5.0e-6, + help="floor for the state prior") + parser.add_argument("--output", required=True, type=str, + help="output path") + # fmt: on + return parser + + +def main(args): + assert args.prior_floor > 0.0 and args.prior_floor < 1.0 + prior = np.zeros((args.prior_dim,), dtype=np.int32) + for path in args.alignment_files: + with open(path, "r", encoding="utf-8") as f: + for line in f: + _, rxfile = line.strip().split(None, 1) + try: + ali = kaldi_io.read_vec_int(rxfile) + except Exception: + raise Exception("failed to read int vector {}.".format(rxfile)) + assert ali is not None and isinstance(ali, np.ndarray) + for id in ali: + prior[id] += 1 + prior = np.maximum(prior / float(np.sum(prior)), args.prior_floor) # normalize and floor + prior = prior / float(np.sum(prior)) # normalize again + kaldi_io.write_vec_flt(args.output, prior) + + logger.info("Saved the initial state prior estimate in {}".format(args.output)) + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + main(args) diff --git a/espresso/tools/generate_log_probs_for_decoding.py b/espresso/tools/generate_log_probs_for_decoding.py new file mode 100644 index 000000000..12eee3e61 --- /dev/null +++ b/espresso/tools/generate_log_probs_for_decoding.py @@ -0,0 +1,68 @@ +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + + +class GenerateLogProbsForDecoding(nn.Module): + def __init__(self, models, retain_dropout=False, apply_log_softmax=False): + """Generate the neural network's output intepreted as log probabilities + for decoding with Kaldi. 
+ + Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models, + currently support fairseq.models.TransformerModel for scripting + retain_dropout (bool, optional): use dropout when generating + (default: False) + apply_log_softmax (bool, optional): apply log-softmax on top of the + network's output (default: False) + """ + super().__init__() + from fairseq.sequence_generator import EnsembleModel + if isinstance(models, EnsembleModel): + self.model = models + else: + self.model = EnsembleModel(models) + self.retain_dropout = retain_dropout + self.apply_log_softmax = apply_log_softmax + + if not self.retain_dropout: + self.model.eval() + + def cuda(self): + self.model.cuda() + return self + + @torch.no_grad() + def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs): + """Generate a batch of translations. + + Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models + sample (dict): batch + """ + self.model.reset_incremental_state() + return self._generate(sample, **kwargs) + + def _generate(self, sample: Dict[str, Dict[str, Tensor]], **kwargs): + encoder_input = {k: v for k, v in sample["net_input"].items()} + src_tokens = encoder_input["src_tokens"] + bsz = src_tokens.size(0) + encoder_outs = self.model.forward_encoder( + src_tokens=encoder_input["src_tokens"], + src_lengths=encoder_input["src_lengths"], + ) + logits = encoder_outs[0].encoder_out.transpose(0, 1).float() # T x B x V -> B x T x V + assert logits.size(0) == bsz + padding_mask = encoder_outs[0].encoder_padding_mask.t() \ + if encoder_outs[0].encoder_padding_mask is not None else None + if self.apply_log_softmax: + return F.log_softmax(logits, dim=-1), padding_mask + return logits, padding_mask diff --git a/examples/asr_wsj/conf/mfcc_hires.conf b/examples/asr_wsj/conf/mfcc_hires.conf new file mode 100644 index 000000000..434834a67 --- /dev/null +++ b/examples/asr_wsj/conf/mfcc_hires.conf @@ -0,0 +1,10 @@ +# config for high-resolution MFCC features, intended for neural network training +# Note: we keep all cepstra, so it has the same info as filterbank features, +# but MFCC is more easily compressible (because less correlated) which is why +# we prefer this method. +--use-energy=false # use average of log energy, not energy. +--num-mel-bins=40 # similar to Google's setup. +--num-ceps=40 # there is no dimensionality reduction. +--low-freq=20 # low cutoff frequency for mel bins... this is high-bandwidth data, so + # there might be some information at the low end. +--high-freq=-400 # high cutoff frequently, relative to Nyquist of 8000 (=7600) diff --git a/examples/asr_wsj/local/common_data_prep.sh b/examples/asr_wsj/local/common_data_prep.sh new file mode 100755 index 000000000..7a7a560cf --- /dev/null +++ b/examples/asr_wsj/local/common_data_prep.sh @@ -0,0 +1,76 @@ +#!/bin/bash +# Copyright (c) Yiwen Shao, Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# The common data preparation script for hybrid systems + +set -euo pipefail + +stage=-10 +nj=30 +train_set=train_si284 +test_set="test_dev93 test_eval92" + +wsj0= +wsj1= +if [[ $(hostname -f) == *.clsp.jhu.edu ]]; then + wsj0=/export/corpora5/LDC/LDC93S6B + wsj1=/export/corpora5/LDC/LDC94S13B +fi + +. ./path.sh +. ./cmd.sh +. ./utils/parse_options.sh + + +if [ $stage -le -4 ]; then + # data preparation + local/wsj_data_prep.sh $wsj0/??-{?,??}.? $wsj1/??-{?,??}.? 
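+  # prepare the pronunciation dictionary, lang directories and language models;
+  # the "_nosp" suffix denotes versions built without silence probabilities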
+ local/wsj_prepare_dict.sh --dict-suffix "_nosp" + utils/prepare_lang.sh data/local/dict_nosp \ + "" data/local/lang_tmp_nosp data/lang_nosp + local/wsj_format_data.sh --lang-suffix "_nosp" + echo "Done formatting the data." + + local/wsj_extend_dict.sh --dict-suffix "_nosp" $wsj1/13-32.1 + utils/prepare_lang.sh data/local/dict_nosp_larger \ + "" data/local/lang_tmp_nosp_larger \ + data/lang_nosp_bd + local/wsj_train_lms.sh --dict-suffix "_nosp" + local/wsj_format_local_lms.sh --lang-suffix "_nosp" + echo "Done exteding the dictionary and formatting LMs." +fi + +if [ $stage -le -3 ]; then + # make MFCC features for the test data + echo "$0: extracting MFCC features for the test sets" + for dataset in $test_set; do + mv data/$dataset data/${dataset}_hires + steps/make_mfcc.sh --cmd "$train_cmd" --nj $nj \ + --mfcc-config conf/mfcc_hires.conf data/${dataset}_hires + steps/compute_cmvn_stats.sh data/${dataset}_hires + done +fi + +if [ $stage -le -2 ]; then + echo "$0: perturbing the training data" + utils/data/get_utt2dur.sh data/$train_set + utils/data/perturb_data_dir_speed_3way.sh data/${train_set} data/${train_set}_sp + utils/copy_data_dir.sh data/${train_set}_sp data/${train_set}_sp_hires + + # do volume-perturbation on the training data prior to extracting hires + # features; this helps make trained nnets more invariant to test data volume. + utils/data/perturb_data_dir_volume.sh data/${train_set}_sp_hires +fi + +if [ $stage -le -1 ]; then + echo "$0: extracting MFCC features for the training data" + steps/make_mfcc.sh --cmd "$train_cmd" --nj $nj \ + --mfcc-config conf/mfcc_hires.conf data/${train_set}_sp_hires + steps/compute_cmvn_stats.sh data/${train_set}_sp_hires + utils/fix_data_dir.sh data/${train_set}_sp_hires +fi + +exit 0; diff --git a/examples/asr_wsj/local/score.sh b/examples/asr_wsj/local/score.sh new file mode 120000 index 000000000..0fc3566e7 --- /dev/null +++ b/examples/asr_wsj/local/score.sh @@ -0,0 +1 @@ +../../../espresso/tools/kaldi/egs/wsj/s5/local/score.sh \ No newline at end of file diff --git a/examples/asr_wsj/path.sh b/examples/asr_wsj/path.sh index 863f5de3e..a190accd1 100644 --- a/examples/asr_wsj/path.sh +++ b/examples/asr_wsj/path.sh @@ -11,5 +11,6 @@ export LC_ALL=C export PATH=~/anaconda3/bin:$PATH export PATH=$MAIN_ROOT:$MAIN_ROOT/espresso:$MAIN_ROOT/espresso/tools:$PATH -export PYTHONPATH=$MAIN_ROOT:$MAIN_ROOT/espresso:$MAIN_ROOT/espresso/tools:$PYTHONPATH +export LD_LIBRARY_PATH=$MAIN_ROOT/espresso/tools/openfst/lib:$LD_LIBRARY_PATH +export PYTHONPATH=$MAIN_ROOT:$MAIN_ROOT/espresso:$MAIN_ROOT/espresso/tools:$MAIN_ROOT/espresso/tools/pychain:$PYTHONPATH export PYTHONUNBUFFERED=1 diff --git a/examples/asr_wsj/run_chain_e2e.sh b/examples/asr_wsj/run_chain_e2e.sh new file mode 100755 index 000000000..79641909b --- /dev/null +++ b/examples/asr_wsj/run_chain_e2e.sh @@ -0,0 +1,238 @@ +#!/bin/bash +# Copyright (c) Yiming Wang, Yiwen Shao +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
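+#
+# WSJ recipe for end-to-end LF-MMI ("chain") training: it prepares the
+# denominator graph and numerator FSTs with Kaldi, dumps CMVN-normalized
+# features and JSON manifests, then trains the model and decodes. All variables
+# below can be overridden on the command line through utils/parse_options.sh,
+# for example:
+#   ./run_chain_e2e.sh --stage 3 --ngpus 1 --free-gpu 0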
+ +set -e -o pipefail + +stage=-10 +ngpus=1 # num GPUs for multiple GPUs training within a single node; should match those in $free_gpu +free_gpu= # comma-separated available GPU ids, eg., "0" or "0,1"; automatically assigned if on CLSP grid + +# model and data related +affix= +lang=data/lang_chain_e2e +tree_dir=exp/chain/e2e_tree # it's actually just a trivial tree (no tree building) +whole_train_set=train_si284_sp # will be split into train_set and valid_set +train_set=train_si284_novalid_spe2e +valid_set=train_si284_valid_spe2e +test_set="test_dev93 test_eval92" +dumpdir=data/dump # directory to dump full features +checkpoint=checkpoint_best.pt + +wsj0= +wsj1= +if [[ $(hostname -f) == *.clsp.jhu.edu ]]; then + wsj0=/export/corpora5/LDC/LDC93S6B + wsj1=/export/corpora5/LDC/LDC94S13B +fi + +. ./path.sh +. ./cmd.sh +. ./utils/parse_options.sh + +dir=exp/tdnn_chain_e2e${affix:+_$affix} + +local/common_data_prep.sh --stage $stage --wsj0 $wsj0 --wsj1 $wsj1 || exit 1; + +if [ $stage -le 0 ]; then + echo "Stage 0: Create the $lang Directory that Has a Specific HMM Topolopy" + rm -rf $lang + cp -r data/lang_nosp $lang + silphonelist=$(cat $lang/phones/silence.csl) || exit 1; + nonsilphonelist=$(cat $lang/phones/nonsilence.csl) || exit 1; + steps/nnet3/chain/gen_topo.py $nonsilphonelist $silphonelist >$lang/topo +fi + +if [ $stage -le 1 ]; then + echo "Stage 1: Generate Denominator Graph and Numerator Fsts" + echo "$0: Estimating a phone language model for the denominator graph..." + mkdir -p $tree_dir/log + $train_cmd $tree_dir/log/make_phone_lm.log \ + cat data/${whole_train_set}_hires/text \| \ + steps/nnet3/chain/e2e/text_to_phones.py --between-silprob 0.1 \ + data/lang_nosp \| \ + utils/sym2int.pl -f 2- data/lang_nosp/phones.txt \| \ + chain-est-phone-lm --num-extra-lm-states=2000 \ + ark:- $tree_dir/phone_lm.fst + nj=32 + steps/nnet3/chain/e2e/prepare_e2e.sh --nj $nj --cmd "$train_cmd" \ + --shared-phones true data/${whole_train_set}_hires $lang $tree_dir + echo "$0: Making denominator fst..." + $decode_cmd $tree_dir/log/make_den_fst.log \ + chain-make-den-fst $tree_dir/tree $tree_dir/0.trans_mdl $tree_dir/phone_lm.fst \ + $tree_dir/den.fst $tree_dir/normalization.fst || exit 1 + echo "$0: Making numerator fsts..." + abs_treedir=`utils/make_absolute.sh $tree_dir` + $decode_cmd JOB=1:$nj $tree_dir/log/make_num_fst_e2e.JOB.log \ + chain-make-num-fst-e2e $tree_dir/0.trans_mdl $tree_dir/normalization.fst \ + scp:$tree_dir/fst.JOB.scp ark,scp:$abs_treedir/fst_nor.JOB.ark,$abs_treedir/fst_nor.JOB.scp || exit 1 + for n in $(seq $nj); do + cat $tree_dir/fst_nor.$n.scp || exit 1 + done > $tree_dir/fst_nor.scp || exit 1 +fi + +if [ ${stage} -le 2 ]; then + echo "Stage 2: Split the Whole Train Set into Train/Valid Set" + # Get list of validation utterances. + data=data/${whole_train_set}_hires + set +e + awk '{print $1}' $data/utt2spk | utils/shuffle_list.pl 2>/dev/null | head -300 > valid_uttlist + set -e + if [ -f $data/utt2uniq ]; then # this matters if you use data augmentation. + echo "File $data/utt2uniq exists, so augmenting valid_uttlist to" + echo "include all perturbed versions of the same 'real' utterances." 
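+    # map the sampled utterances to their base ids via utt2uniq, then expand back
+    # to all perturbed copies so that no variant of a held-out utterance remains
+    # in the training subset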
+ mv valid_uttlist valid_uttlist.tmp + utils/utt2spk_to_spk2utt.pl $data/utt2uniq > uniq2utt + cat valid_uttlist.tmp | utils/apply_map.pl $data/utt2uniq | \ + sort | uniq | utils/apply_map.pl uniq2utt | \ + awk '{for(n=1;n<=NF;n++) print $n;}' | sort > valid_uttlist + rm uniq2utt valid_uttlist.tmp 2>/dev/null + fi + # generate train/valid data dir + utils/filter_scp.pl --exclude valid_uttlist $data/utt2spk | cut -d" " -f1 > novalid_uttlist || exit 1 + utils/subset_data_dir.sh --utt-list novalid_uttlist $data data/${train_set}_hires || exit 1 + utils/subset_data_dir.sh --utt-list valid_uttlist $data data/${valid_set}_hires || exit 1 + + # generate train/valid numerator fst file + utils/filter_scp.pl novalid_uttlist $tree_dir/fst_nor.scp > $tree_dir/fst_novalid_nor.scp || exit 1 + utils/filter_scp.pl valid_uttlist $tree_dir/fst_nor.scp > $tree_dir/fst_valid_nor.scp || exit 1 + rm valid_uttlist novalid_uttlist 2>/dev/null + + # not all fsts can be generated successfully, just filter out those not having the fst + for dataset in $train_set $valid_set; do + tag=novalid && [[ "$dataset" == "$valid_set" ]] && tag=valid + cp data/${dataset}_hires/feats.scp data/${dataset}_hires/feats.scp.tmp + utils/filter_scp.pl $tree_dir/fst_${tag}_nor.scp data/${dataset}_hires/feats.scp.tmp \ + > data/${dataset}_hires/feats.scp || exit 1 + rm data/${dataset}_hires/feats.scp.tmp 2>/dev/null + utils/fix_data_dir.sh data/${dataset}_hires || exit 1 + done +fi + +if [ ${stage} -le 3 ]; then + echo "Stage 3: Dump Feature" + for dataset in $train_set $valid_set $test_set; do + nj=8 + utils/split_data.sh data/${dataset}_hires $nj + sdata=data/${dataset}_hires/split$nj + mkdir -p $dumpdir/${dataset}_hires; abs_featdir=`utils/make_absolute.sh $dumpdir/${dataset}_hires` + $train_cmd JOB=1:$nj $abs_featdir/log/dump_feature.JOB.log \ + apply-cmvn --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp \ + scp:$sdata/JOB/feats.scp ark:- \| \ + copy-feats --compress=true --compression-method=2 ark:- \ + ark,scp:$abs_featdir/feats.JOB.ark,$abs_featdir/feats.JOB.scp || exit 1 + for n in $(seq $nj); do + cat $abs_featdir/feats.$n.scp || exit 1 + done > $abs_featdir/feats.scp || exit 1 + rm $abs_featdir/feats.*.scp 2>/dev/null + cat data/${dataset}_hires/utt2num_frames > $abs_featdir/utt2num_frames || exit 1 + cat data/${dataset}_hires/utt2spk > $abs_featdir/utt2spk || exit 1 + done +fi + +if [ ${stage} -le 4 ]; then + echo "Stage 4: Make Graphs" + for lmtype in tgpr bd_tgpr; do + utils/lang/check_phones_compatible.sh \ + data/lang_nosp_test_$lmtype/phones.txt $lang/phones.txt + utils/mkgraph.sh --self-loop-scale 1.0 data/lang_nosp_test_$lmtype $tree_dir $tree_dir/graph_$lmtype || exit 1 + done +fi + +if [ ${stage} -le 5 ]; then + echo "Stage 5: Dump Json Files" + train_feat=$dumpdir/${train_set}_hires/feats.scp + train_fst=${tree_dir}/fst_novalid_nor.scp + train_text=data/${train_set}_hires/text + train_utt2num_frames=data/${train_set}_hires/utt2num_frames + valid_feat=$dumpdir/${valid_set}_hires/feats.scp + valid_fst=${tree_dir}/fst_valid_nor.scp + valid_text=data/${valid_set}_hires/text + valid_utt2num_frames=data/${valid_set}_hires/utt2num_frames + mkdir -p data/chain_e2e + asr_prep_json.py --feat-files $train_feat --numerator-fst-files $train_fst --text-files $train_text \ + --utt2num-frames-files $train_utt2num_frames --output data/chain_e2e/train.json + asr_prep_json.py --feat-files $valid_feat --numerator-fst-files $valid_fst --text-files $valid_text \ + --utt2num-frames-files $valid_utt2num_frames --output 
data/chain_e2e/valid.json + for dataset in $test_set; do + nj=$(wc -l &1 | tee $log_file +fi + +if [ ${stage} -le 7 ]; then + echo "Stage 7: Decoding" + rm $dir/.error 2>/dev/null || true + queue_opt="--num-threads 4" + path=$dir/$checkpoint + for dataset in $test_set; do + ( + data_affix=$(echo $dataset | sed s/test_//) + nj=$(wc -l $dir/decode_${lmtype}_${data_affix}/lat.JOB.gz" || exit 1 + local/score.sh --cmd "$decode_cmd" data/${dataset}_hires $graph_dir $dir/decode_${lmtype}_${data_affix} || exit 1 + echo $nj > $dir/decode_${lmtype}_${data_affix}/num_jobs + done + steps/lmrescore.sh --cmd "$decode_cmd" --self-loop-scale 1.0 --mode 3 data/lang_nosp_test_{tgpr,tg} \ + data/${dataset}_hires $dir/decode_{tgpr,tg}_${data_affix} || exit 1 + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" data/lang_nosp_test_bd_{tgpr,fgconst} \ + data/${dataset}_hires $dir/decode_bd_tgpr_${data_affix}{,_fg} || exit 1 + ) || touch $dir/.error & + done + wait + [ -f $dir/.error ] && echo "$0: there was a problem while decoding" && exit 1 + for dataset in $test_set; do + data_affix=$(echo $dataset | sed s/test_//) + for x in $dir/decode_{tgpr_${data_affix},tg_${data_affix},bd_tgpr_${data_affix},bd_tgpr_${data_affix}_fg}; do + grep WER $x/wer_* | utils/best_wer.sh + done + done +fi diff --git a/examples/asr_wsj/run_xent.sh b/examples/asr_wsj/run_xent.sh new file mode 100755 index 000000000..f5a8fc1c5 --- /dev/null +++ b/examples/asr_wsj/run_xent.sh @@ -0,0 +1,217 @@ +#!/bin/bash +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# need to make the following soft links to corresponding dirs in Kaldi: +# ln -s /egs/wsj/s5/exp/tri4b exp/tri4b +# ln -s /egs/wsj/s5/exp/tri4b_ali_train_si284_sp data/tri4b_ali_train_si284_sp +# ln -s /egs/wsj/s5/data/lang_test_tgpr data/lang_test_tgpr +# ln -s /egs/wsj/s5/data/lang_test_tg data/lang_test_tg +# ln -s /egs/wsj/s5/data/lang_test_bd_tgpr data/lang_test_bd_tgpr +# ln -s /egs/wsj/s5/data/lang_test_bd_fgconst data/lang_test_bd_fgconst +# ln -s /egs/wsj/s5/data/train_si284_sp_hires data/train_si284_sp_hires +# ln -s /egs/wsj/s5/data/test_dev93_hires data/test_dev93_hires +# ln -s /egs/wsj/s5/data/test_eval92_hires data/test_eval92_hires + +set -e -o pipefail + +stage=0 +ngpus=1 # num GPUs for multiple GPUs training within a single node; should match those in $free_gpu +free_gpu= # comma-separated available GPU ids, eg., "0" or "0,1"; automatically assigned if on CLSP grid + +# model and data related +affix= +gmm=tri4b +lang=data/lang_test +whole_train_set=train_si284_sp # will be split into train_set and valid_set +train_set=train_si284_novalid_sp +valid_set=train_si284_valid_sp +test_set="test_dev93 test_eval92" +dumpdir=data/dump # directory to dump full features +checkpoint=checkpoint_best.pt + + +. ./path.sh +. ./cmd.sh +. ./utils/parse_options.sh + +dir=exp/tdnn_xent${affix:+_$affix} + +if [ ${stage} -le 0 ]; then + echo "Stage 0: Split the Whole Train Set into Train/Valid Data and Ali Dirs" + # Get list of validation utterances. + data=data/${whole_train_set}_hires + set +e + awk '{print $1}' $data/utt2spk | utils/shuffle_list.pl 2>/dev/null | head -300 > valid_uttlist + set -e + if [ -f $data/utt2uniq ]; then # this matters if you use data augmentation. + echo "File $data/utt2uniq exists, so augmenting valid_uttlist to" + echo "include all perturbed versions of the same 'real' utterances." 
+ mv valid_uttlist valid_uttlist.tmp + utils/utt2spk_to_spk2utt.pl $data/utt2uniq > uniq2utt + cat valid_uttlist.tmp | utils/apply_map.pl $data/utt2uniq | \ + sort | uniq | utils/apply_map.pl uniq2utt | \ + awk '{for(n=1;n<=NF;n++) print $n;}' | sort > valid_uttlist + rm uniq2utt valid_uttlist.tmp 2>/dev/null + fi + # generate train/valid data dir + utils/filter_scp.pl --exclude valid_uttlist $data/utt2spk | cut -d" " -f1 > novalid_uttlist || exit 1 + utils/subset_data_dir.sh --utt-list novalid_uttlist $data data/${train_set}_hires || exit 1 + utils/subset_data_dir.sh --utt-list valid_uttlist $data data/${valid_set}_hires || exit 1 + rm valid_uttlist novalid_uttlist 2>/dev/null + + # generate train/valid ali dir + steps/subset_ali_dir.sh $data data/${train_set}_hires data/${gmm}_ali_${whole_train_set} \ + data/${gmm}_ali_${train_set} || exit 1 + steps/subset_ali_dir.sh $data data/${valid_set}_hires data/${gmm}_ali_${whole_train_set} \ + data/${gmm}_ali_${valid_set} || exit 1 +fi + +if [ ${stage} -le 1 ]; then + echo "Stage 1: Convert Alignments from transition-ids to pdf-ids" + for dataset in $train_set $valid_set; do + abs_alidir=`utils/make_absolute.sh data/${gmm}_ali_$dataset` + nj=$(cat ${abs_alidir}/num_jobs) + $decode_cmd JOB=1:$nj ${abs_alidir}/log/ali_to_pdf.JOB.log \ + ali-to-pdf ${abs_alidir}/final.mdl \ + "ark,s,cs:gunzip -c ${abs_alidir}/ali.JOB.gz |" \ + ark,scp:${abs_alidir}/ali_pdf.JOB.ark,${abs_alidir}/ali_pdf.JOB.scp || exit 1 + for n in $(seq $nj); do + cat ${abs_alidir}/ali_pdf.$n.scp || exit 1 + done > ${abs_alidir}/ali_pdf.scp || exit 1 + rm ${abs_alidir}/ali_pdf.*.scp 2>/dev/null + + # not all alignments can be generated successfully, just filter out those not having the alignment + cp data/${dataset}_hires/feats.scp data/${dataset}_hires/feats.scp.tmp + utils/filter_scp.pl ${abs_alidir}/ali_pdf.scp data/${dataset}_hires/feats.scp.tmp \ + > data/${dataset}_hires/feats.scp || exit 1 + rm data/${dataset}_hires/feats.scp.tmp 2>/dev/null + utils/fix_data_dir.sh data/${dataset}_hires || exit 1 + done +fi + +if [ ${stage} -le 2 ]; then + echo "Stage 2: Dump Feature" + for dataset in $train_set $valid_set $test_set; do + nj=8 + utils/split_data.sh data/${dataset}_hires $nj + sdata=data/${dataset}_hires/split$nj + mkdir -p $dumpdir/${dataset}_hires; abs_featdir=`utils/make_absolute.sh $dumpdir/${dataset}_hires` + $train_cmd JOB=1:$nj $abs_featdir/log/dump_feature.JOB.log \ + apply-cmvn --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp \ + scp:$sdata/JOB/feats.scp ark:- \| \ + copy-feats --compress=true --compression-method=2 ark:- \ + ark,scp:$abs_featdir/feats.JOB.ark,$abs_featdir/feats.JOB.scp || exit 1 + for n in $(seq $nj); do + cat $abs_featdir/feats.$n.scp || exit 1 + done > $abs_featdir/feats.scp || exit 1 + rm $abs_featdir/feats.*.scp 2>/dev/null + cat data/${dataset}_hires/utt2num_frames > $abs_featdir/utt2num_frames || exit 1 + cat data/${dataset}_hires/utt2spk > $abs_featdir/utt2spk || exit 1 + done +fi + +if [ ${stage} -le 3 ]; then + echo "Stage 3: Make Graphs" + for lmtype in tgpr bd_tgpr; do + utils/mkgraph.sh ${lang}_$lmtype exp/$gmm exp/$gmm/graph_$lmtype || exit 1 + done +fi + +num_targets=$(tree-info data/${gmm}_ali_${train_set}/tree | grep num-pdfs | awk '{print $2}') +state_prior_file=data/xent/state_prior.vec + +if [ ${stage} -le 4 ]; then + echo "Stage 4: Dump Json Files and Estimate Initial State Prior from Alignments" + train_feat=$dumpdir/${train_set}_hires/feats.scp + train_ali=data/${gmm}_ali_${train_set}/ali_pdf.scp + 
train_text=data/${train_set}_hires/text + train_utt2num_frames=data/${train_set}_hires/utt2num_frames + valid_feat=$dumpdir/${valid_set}_hires/feats.scp + valid_ali=data/${gmm}_ali_${valid_set}/ali_pdf.scp + valid_text=data/${valid_set}_hires/text + valid_utt2num_frames=data/${valid_set}_hires/utt2num_frames + mkdir -p data/xent + asr_prep_json.py --feat-files $train_feat --alignment-file $train_ali --text-files $train_text --utt2num-frames-files $train_utt2num_frames --output data/xent/train.json + asr_prep_json.py --feat-files $valid_feat --alignment-file $valid_ali --text-files $valid_text --utt2num-frames-files $valid_utt2num_frames --output data/xent/valid.json + for dataset in $test_set; do + nj=$(wc -l &1 | tee $log_file +fi + +if [ ${stage} -le 6 ]; then + echo "Stage 6: Decoding" + rm $dir/.error 2>/dev/null || true + queue_opt="--num-threads 4" + path=$dir/$checkpoint + for dataset in $test_set; do + ( + data_affix=$(echo $dataset | sed s/test_//) + nj=$(wc -l $dir/decode_${lmtype}_${data_affix}/lat.JOB.gz" || exit 1 + local/score.sh --cmd "$decode_cmd" data/${dataset}_hires $graph_dir $dir/decode_${lmtype}_${data_affix} || exit 1 + echo $nj > $dir/decode_${lmtype}_${data_affix}/num_jobs + done + steps/lmrescore.sh --cmd "$decode_cmd" --mode 3 ${lang}_{tgpr,tg} \ + data/${dataset}_hires $dir/decode_{tgpr,tg}_${data_affix} || exit 1 + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" ${lang}_bd_{tgpr,fgconst} \ + data/${dataset}_hires $dir/decode_bd_tgpr_${data_affix}{,_fg} || exit 1 + ) || touch $dir/.error & + done + wait + [ -f $dir/.error ] && echo "$0: there was a problem while decoding" && exit 1 + for dataset in $test_set; do + data_affix=$(echo $dataset | sed s/test_//) + for x in $dir/decode_{tgpr_${data_affix},tg_${data_affix},bd_tgpr_${data_affix},bd_tgpr_${data_affix}_fg}; do + grep WER $x/wer_* | utils/best_wer.sh + done + done +fi diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index 7a8aac47f..e6708a1f7 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -1034,8 +1034,8 @@ def __init__(self, models, lm_weight): @torch.jit.export def forward_encoder(self, src_tokens, src_lengths): return [ - model.encoder(src_tokens=src_tokens, src_lengths=src_lengths) if hasattr(model, "encoder") \ - else None for model in self.models + model.encoder(src_tokens=src_tokens, src_lengths=src_lengths) if hasattr(model, "encoder") + else None for model in self.models ] @torch.jit.export From be4fd6a2641037f22c0658eb4abf08dbd5473d1c Mon Sep 17 00:00:00 2001 From: freewym Date: Sun, 10 May 2020 15:24:17 -0400 Subject: [PATCH 081/119] code adaptation/changes according to the commits on May 10 --- espresso/tools/generate_log_probs_for_decoding.py | 11 +++++------ espresso/tools/simple_greedy_decoder.py | 13 ++++--------- examples/asr_wsj/run_chain_e2e.sh | 1 + fairseq/sequence_generator.py | 6 +++--- 4 files changed, 13 insertions(+), 18 deletions(-) diff --git a/espresso/tools/generate_log_probs_for_decoding.py b/espresso/tools/generate_log_probs_for_decoding.py index 12eee3e61..1b1a44dca 100644 --- a/espresso/tools/generate_log_probs_for_decoding.py +++ b/espresso/tools/generate_log_probs_for_decoding.py @@ -52,13 +52,12 @@ def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs): return self._generate(sample, **kwargs) def _generate(self, sample: Dict[str, Dict[str, Tensor]], **kwargs): - encoder_input = {k: v for k, v in sample["net_input"].items()} - src_tokens = encoder_input["src_tokens"] + net_input = 
sample["net_input"] + src_tokens = net_input["src_tokens"] bsz = src_tokens.size(0) - encoder_outs = self.model.forward_encoder( - src_tokens=encoder_input["src_tokens"], - src_lengths=encoder_input["src_lengths"], - ) + + # compute the encoder output + encoder_outs = self.model.forward_encoder(net_input) logits = encoder_outs[0].encoder_out.transpose(0, 1).float() # T x B x V -> B x T x V assert logits.size(0) == bsz padding_mask = encoder_outs[0].encoder_padding_mask.t() \ diff --git a/espresso/tools/simple_greedy_decoder.py b/espresso/tools/simple_greedy_decoder.py index 95122ca7f..3602bc348 100644 --- a/espresso/tools/simple_greedy_decoder.py +++ b/espresso/tools/simple_greedy_decoder.py @@ -73,18 +73,13 @@ def decode(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs): @torch.no_grad() def _decode(self, sample: Dict[str, Dict[str, Tensor]], bos_token: Optional[int] = None): - encoder_input: Dict[str, Tensor] = {} - for k, v in sample["net_input"].items(): - if k != "prev_output_tokens": - encoder_input[k] = v - src_tokens = encoder_input["src_tokens"] + net_input = sample["net_input"] + src_tokens = net_input["src_tokens"] input_size = src_tokens.size() bsz, src_len = input_size[0], input_size[1] - encoder_outs = self.model.forward_encoder( - src_tokens=encoder_input["src_tokens"], - src_lengths=encoder_input["src_lengths"], - ) + # compute the encoder output + encoder_outs = self.model.forward_encoder(net_input) target = sample["target"] # target can only be None if not for validation assert target is not None or not self.for_validation diff --git a/examples/asr_wsj/run_chain_e2e.sh b/examples/asr_wsj/run_chain_e2e.sh index 79641909b..669865eef 100755 --- a/examples/asr_wsj/run_chain_e2e.sh +++ b/examples/asr_wsj/run_chain_e2e.sh @@ -159,6 +159,7 @@ if [ ${stage} -le 5 ]; then for dataset in $test_set; do nj=$(wc -l 2: - src_lengths = encoder_input["src_lengths"] + src_lengths = net_input["src_lengths"] else: # length of the source text being the character length except EndOfSentence and pad src_lengths = ( @@ -1032,9 +1032,9 @@ def __init__(self, models, lm_weight): assert self.has_encoder() @torch.jit.export - def forward_encoder(self, src_tokens, src_lengths): + def forward_encoder(self, net_input: Dict[str, Tensor]): return [ - model.encoder(src_tokens=src_tokens, src_lengths=src_lengths) if hasattr(model, "encoder") + model.encoder.forward_torchscript(net_input) if hasattr(model, "encoder") else None for model in self.models ] From 967dd34859a1520223e396a246e55f1a44d12d03 Mon Sep 17 00:00:00 2001 From: freewym Date: Tue, 19 May 2020 15:00:51 -0400 Subject: [PATCH 082/119] code adaptation/changes according to the commits on May 18 --- espresso/models/speech_lstm.py | 23 +++++- espresso/models/speech_lstm_encoder_model.py | 21 +++++- espresso/speech_train.py | 77 +++++++++++++------- 3 files changed, 90 insertions(+), 31 deletions(-) diff --git a/espresso/models/speech_lstm.py b/espresso/models/speech_lstm.py index 71ba0d371..acf926a13 100644 --- a/espresso/models/speech_lstm.py +++ b/espresso/models/speech_lstm.py @@ -360,7 +360,24 @@ def output_lengths(self, in_lengths): return in_lengths if self.conv_layers_before is None \ else self.conv_layers_before.output_lengths(in_lengths) - def forward(self, src_tokens, src_lengths: Tensor, **unused): + def forward( + self, + src_tokens: Tensor, + src_lengths: Tensor, + enforce_sorted: bool = True, + **unused, + ): + """ + Args: + src_tokens (LongTensor): tokens in the source language of + shape `(batch, src_len)` + src_lengths 
(LongTensor): lengths of each source sentence of + shape `(batch)` + enforce_sorted (bool, optional): if True, `src_tokens` is + expected to contain sequences sorted by length in a + decreasing order. If False, this condition is not + required. Default: True. + """ if self.left_pad: # nn.utils.rnn.pack_padded_sequence requires right-padding; # convert left-padding to right-padding @@ -390,7 +407,9 @@ def forward(self, src_tokens, src_lengths: Tensor, **unused): if self.residual and i > 0: # residual connection starts from the 2nd layer prev_x = x # pack embedded source tokens into a PackedSequence - packed_x = nn.utils.rnn.pack_padded_sequence(x, src_lengths.data) + packed_x = nn.utils.rnn.pack_padded_sequence( + x, src_lengths.data, enforce_sorted=enforce_sorted + ) # apply LSTM packed_outs, (_, _) = self.lstm[i](packed_x, (h0, c0)) diff --git a/espresso/models/speech_lstm_encoder_model.py b/espresso/models/speech_lstm_encoder_model.py index ef1bc10c5..bb201d3f2 100644 --- a/espresso/models/speech_lstm_encoder_model.py +++ b/espresso/models/speech_lstm_encoder_model.py @@ -177,8 +177,25 @@ def __init__( self.fc_out = Linear(self.output_units, num_targets, dropout=dropout_out) \ if num_targets is not None else None - def forward(self, src_tokens, src_lengths: Tensor, **unused): - out = super().forward(src_tokens, src_lengths, **unused) + def forward( + self, + src_tokens, + src_lengths: Tensor, + enforce_sorted: bool = True, + **unused, + ): + """ + Args: + src_tokens (LongTensor): tokens in the source language of + shape `(batch, src_len)` + src_lengths (LongTensor): lengths of each source sentence of + shape `(batch)` + enforce_sorted (bool, optional): if True, `src_tokens` is + expected to contain sequences sorted by length in a + decreasing order. If False, this condition is not + required. Default: True. 
+ """ + out = super().forward(src_tokens, src_lengths, enforce_sorted=enforce_sorted, **unused) x, encoder_padding_mask, x_lengths = out.encoder_out, out.encoder_padding_mask, out.src_lengths # determine which output frame to select for loss evaluation/test, assuming diff --git a/espresso/speech_train.py b/espresso/speech_train.py index d05eb60ae..6427fdd3d 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -48,10 +48,10 @@ def main(args, init_distributed=False): metrics.reset() # Initialize CUDA and distributed training - if torch.cuda.is_available() and not args.cpu: + if torch.cuda.is_available() and not args.cpu and not getattr(args, 'tpu', False): torch.cuda.set_device(args.device_id) np.random.seed(args.seed) - torch.manual_seed(args.seed) + utils.set_torch_seed(args.seed) if init_distributed: args.distributed_rank = distributed_utils.distributed_init(args) @@ -94,7 +94,7 @@ def main(args, init_distributed=False): else: trainer = MegatronTrainer(args, task, model, criterion) - logger.info('training on {} GPUs'.format(args.distributed_world_size)) + logger.info('training on {} devices (GPUs/TPUs)'.format(args.distributed_world_size)) logger.info('max input frames per GPU = {} and max sentences per GPU = {}'.format( args.max_tokens, args.max_sentences, @@ -103,6 +103,10 @@ def main(args, init_distributed=False): # Load the latest checkpoint if one is available and restore the # corresponding train iterator extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer) + if args.tpu: + import torch_xla.core.xla_model as xm + xm.rendezvous('load_checkpoint') # wait for all workers + xm.mark_step() # Train until the learning rate gets too small max_epoch = args.max_epoch or math.inf @@ -155,6 +159,19 @@ def is_better(a, b): return False +def tpu_data_loader(args, itr): + import torch_xla.core.xla_model as xm + import torch_xla.distributed.parallel_loader as pl + xm.rendezvous('tpu_data_loader') # wait for all workers + xm.mark_step() + device = utils.get_tpu_device(args) + return iterators.CountingIterator( + pl.ParallelLoader(itr, [device]).per_device_loader(device), + start=getattr(itr, 'n', 0), + total=len(itr), + ) + + @metrics.aggregate('train') def train(args, trainer, task, epoch_itr, max_update=math.inf): """Train the model for one epoch and return validation losses.""" @@ -169,6 +186,8 @@ def train(args, trainer, task, epoch_itr, max_update=math.inf): else args.update_freq[-1] ) itr = iterators.GroupedIterator(itr, update_freq) + if getattr(args, 'tpu', False): + itr = tpu_data_loader(args, itr) progress = progress_bar.progress_bar( itr, log_format=args.log_format, @@ -206,7 +225,10 @@ def train(args, trainer, task, epoch_itr, max_update=math.inf): if hasattr(task, 'update_state_prior'): task.update_state_prior(trainer.get_model()) - valid_losses = validate_and_save(args, trainer, task, epoch_itr, valid_subsets) + end_of_epoch = not itr.has_next() + valid_losses = validate_and_save( + args, trainer, task, epoch_itr, valid_subsets, end_of_epoch + ) if should_stop_early(args, valid_losses[0]) or num_updates >= max_update: break @@ -219,7 +241,7 @@ def train(args, trainer, task, epoch_itr, max_update=math.inf): return valid_losses -def validate_and_save(args, trainer, task, epoch_itr, valid_subsets): +def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoch): num_updates = trainer.get_num_updates() do_save = ( ( @@ -227,18 +249,12 @@ def validate_and_save(args, trainer, task, epoch_itr, valid_subsets): and num_updates > 0 and 
num_updates % args.save_interval_updates == 0 ) - or ( - epoch_itr.end_of_epoch() - and epoch_itr.epoch % args.save_interval == 0 - ) + or (end_of_epoch and epoch_itr.epoch % args.save_interval == 0) ) do_validate = ( ( do_save # saving requires validation - or ( - epoch_itr.end_of_epoch() - and epoch_itr.epoch % args.validate_interval == 0 - ) + or (end_of_epoch and epoch_itr.epoch % args.validate_interval == 0) ) and not args.disable_validation ) @@ -254,8 +270,6 @@ def validate_and_save(args, trainer, task, epoch_itr, valid_subsets): def get_training_stats(stats): - if 'nll_loss' in stats and 'ppl' not in stats: - stats['ppl'] = utils.get_perplexity(stats['nll_loss']) stats['wall'] = round(metrics.get_meter('default', 'wall').elapsed_time, 0) return stats @@ -285,6 +299,8 @@ def validate(args, trainer, task, epoch_itr, subsets): shard_id=args.distributed_rank, num_workers=args.num_workers, ).next_epoch_itr(shuffle=False) + if getattr(args, 'tpu', False): + itr = tpu_data_loader(args, itr) progress = progress_bar.progress_bar( itr, log_format=args.log_format, @@ -312,8 +328,6 @@ def validate(args, trainer, task, epoch_itr, subsets): def get_valid_stats(args, trainer, stats): - if 'nll_loss' in stats and 'ppl' not in stats: - stats['ppl'] = utils.get_perplexity(stats['nll_loss']) stats['num_updates'] = trainer.get_num_updates() if hasattr(checkpoint_utils.save_checkpoint, 'best'): key = 'best_{0}'.format(args.best_checkpoint_metric) @@ -360,16 +374,25 @@ def cli_main(modify_parser=None): else: distributed_main(args.device_id, args) elif args.distributed_world_size > 1: - # fallback for single node with multiple GPUs - assert args.distributed_world_size <= torch.cuda.device_count() - port = random.randint(10000, 20000) - args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port) - args.distributed_rank = None # set based on device id - torch.multiprocessing.spawn( - fn=distributed_main, - args=(args, ), - nprocs=args.distributed_world_size, - ) + if not getattr(args, 'tpu', False): + # fallback for single node with multiple GPUs + assert args.distributed_world_size <= torch.cuda.device_count() + port = random.randint(10000, 20000) + args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port) + args.distributed_rank = None # set based on device id + torch.multiprocessing.spawn( + fn=distributed_main, + args=(args, ), + nprocs=args.distributed_world_size, + ) + else: + import torch_xla.distributed.xla_multiprocessing as xmp + torch.multiprocessing.set_sharing_strategy('file_system') + xmp.spawn( + fn=distributed_main, + args=(args, ), + nprocs=8, # use all 8 TPU cores + ) else: # single GPU training main(args) From af5a56479136cb6ae97715c5d1fdb2de268fcf67 Mon Sep 17 00:00:00 2001 From: freewym Date: Wed, 27 May 2020 00:39:10 -0400 Subject: [PATCH 083/119] fix lf-mmi loss; code adaptation/changes according to the commits on May 27 --- espresso/criterions/lf_mmi_loss.py | 2 +- espresso/speech_train.py | 47 +++++++++++++----------------- examples/asr_librispeech/run.sh | 2 +- examples/asr_swbd/run.sh | 2 +- examples/asr_wsj/run.sh | 2 +- examples/asr_wsj/run_chain_e2e.sh | 2 +- examples/asr_wsj/run_xent.sh | 2 +- 7 files changed, 27 insertions(+), 32 deletions(-) diff --git a/espresso/criterions/lf_mmi_loss.py b/espresso/criterions/lf_mmi_loss.py index e7636a59b..5fdd74d49 100644 --- a/espresso/criterions/lf_mmi_loss.py +++ b/espresso/criterions/lf_mmi_loss.py @@ -54,7 +54,7 @@ def forward(self, model, sample, reduce=True): net_output = model(**sample["net_input"]) 
loss, _ = self.compute_loss(net_output, sample, reduce=reduce) - sample_size = sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + sample_size = sample["target"].batch_size if self.sentence_avg else sample["ntokens"] logging_output = { "loss": loss.data, "ntokens": sample["ntokens"], diff --git a/espresso/speech_train.py b/espresso/speech_train.py index 6427fdd3d..53edfcb89 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -110,7 +110,6 @@ def main(args, init_distributed=False): # Train until the learning rate gets too small max_epoch = args.max_epoch or math.inf - max_update = args.max_update or math.inf lr = trainer.get_lr() train_meter = meters.StopwatchMeter() train_meter.start() @@ -119,8 +118,8 @@ def main(args, init_distributed=False): and epoch_itr.next_epoch_idx <= max_epoch ): # train for one epoch - valid_losses = train(args, trainer, task, epoch_itr, max_update) - if should_stop_early(args, valid_losses[0]) or trainer.get_num_updates() >= max_update: + valid_losses, should_stop = train(args, trainer, task, epoch_itr) + if should_stop: break # only use first validation loss to update the learning rate @@ -173,7 +172,7 @@ def tpu_data_loader(args, itr): @metrics.aggregate('train') -def train(args, trainer, task, epoch_itr, max_update=math.inf): +def train(args, trainer, task, epoch_itr): """Train the model for one epoch and return validation losses.""" # Initialize data iterator itr = epoch_itr.next_epoch_itr( @@ -205,6 +204,7 @@ def train(args, trainer, task, epoch_itr, max_update=math.inf): trainer.criterion.set_epoch(epoch_itr.epoch) valid_subsets = args.valid_subset.split(',') + should_stop = False for samples in progress: with metrics.aggregate('train_inner'): log_output = trainer.train_step(samples) @@ -226,10 +226,10 @@ def train(args, trainer, task, epoch_itr, max_update=math.inf): task.update_state_prior(trainer.get_model()) end_of_epoch = not itr.has_next() - valid_losses = validate_and_save( + valid_losses, should_stop = validate_and_save( args, trainer, task, epoch_itr, valid_subsets, end_of_epoch ) - if should_stop_early(args, valid_losses[0]) or num_updates >= max_update: + if should_stop: break # log end-of-epoch stats @@ -238,7 +238,7 @@ def train(args, trainer, task, epoch_itr, max_update=math.inf): # reset epoch-level meters metrics.reset_meters('train') - return valid_losses + return valid_losses, should_stop def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoch): @@ -253,7 +253,7 @@ def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoc ) do_validate = ( ( - do_save # saving requires validation + (not end_of_epoch and do_save) # validate during mid-epoch saves or (end_of_epoch and epoch_itr.epoch % args.validate_interval == 0) ) and not args.disable_validation @@ -263,10 +263,19 @@ def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoc valid_losses = [None] if do_validate: valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets) - # Save - if do_save: + + # Stopping conditions + max_update = args.max_update or math.inf + should_stop = ( + should_stop_early(args, valid_losses[0]) + or trainer.get_num_updates() >= max_update + ) + + # Save checkpoint + if do_save or should_stop: checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) - return valid_losses + + return valid_losses, should_stop def get_training_stats(stats): @@ -284,21 +293,7 @@ def validate(args, trainer, task, epoch_itr, subsets): valid_losses 
= [] for subset in subsets: # Initialize data iterator - itr = task.get_batch_iterator( - dataset=task.dataset(subset), - max_tokens=args.max_tokens_valid, - max_sentences=args.max_sentences_valid, - max_positions=utils.resolve_max_positions( - task.max_positions(), - trainer.get_model().max_positions(), - ), - ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, - required_batch_size_multiple=args.required_batch_size_multiple, - seed=args.seed, - num_shards=args.distributed_world_size, - shard_id=args.distributed_rank, - num_workers=args.num_workers, - ).next_epoch_itr(shuffle=False) + itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False) if getattr(args, 'tpu', False): itr = tpu_data_loader(args, itr) progress = progress_bar.progress_bar( diff --git a/examples/asr_librispeech/run.sh b/examples/asr_librispeech/run.sh index f8051492e..b55e71fe4 100755 --- a/examples/asr_librispeech/run.sh +++ b/examples/asr_librispeech/run.sh @@ -230,7 +230,7 @@ if [ ${stage} -le 8 ]; then fi CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data --task speech_recognition_espresso --seed 1 --user-dir espresso \ --log-interval $((8000/ngpus)) --log-format simple --print-training-sample-interval $((4000/ngpus)) \ - --num-workers 0 --max-tokens 26000 --max-sentences 24 --curriculum 1 \ + --num-workers 0 --data-buffer-size 0 --max-tokens 26000 --max-sentences 24 --curriculum 1 \ --valid-subset $valid_subset --max-sentences-valid 48 --ddp-backend no_c10d \ --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ --optimizer adam --lr 0.001 --weight-decay 0.0 --clip-norm 2.0 \ diff --git a/examples/asr_swbd/run.sh b/examples/asr_swbd/run.sh index 718fe10de..485a67ffd 100755 --- a/examples/asr_swbd/run.sh +++ b/examples/asr_swbd/run.sh @@ -271,7 +271,7 @@ if [ $stage -le 7 ]; then fi CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data --task speech_recognition_espresso --seed 1 --user-dir espresso \ --log-interval $((3000/ngpus)) --log-format simple --print-training-sample-interval $((4000/ngpus)) \ - --num-workers 0 --max-tokens 26000 --max-sentences 48 --curriculum 2 \ + --num-workers 0 --data-buffer-size 0 --max-tokens 26000 --max-sentences 48 --curriculum 2 \ --valid-subset $valid_subset --max-sentences-valid 64 --ddp-backend no_c10d \ --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ --optimizer adam --lr 0.001 --weight-decay 0.0 --clip-norm 2.0 \ diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index 60c775339..d2f2925ee 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -271,7 +271,7 @@ if [ ${stage} -le 9 ]; then [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data --task speech_recognition_espresso --seed 1 --user-dir espresso \ --log-interval $((800/ngpus)) --log-format simple --print-training-sample-interval $((2000/ngpus)) \ - --num-workers 0 --max-tokens 24000 --max-sentences 32 --curriculum 2 \ + --num-workers 0 --data-buffer-size 0 --max-tokens 24000 --max-sentences 32 --curriculum 2 \ --valid-subset $valid_subset --max-sentences-valid 64 --ddp-backend no_c10d \ --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ --max-epoch 35 --optimizer adam --lr 0.001 --weight-decay 0.0 \ diff --git a/examples/asr_wsj/run_chain_e2e.sh b/examples/asr_wsj/run_chain_e2e.sh index 669865eef..7331815cc 100755 --- 
a/examples/asr_wsj/run_chain_e2e.sh +++ b/examples/asr_wsj/run_chain_e2e.sh @@ -185,7 +185,7 @@ if [ ${stage} -le 6 ]; then log_file=$dir/log/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data/chain_e2e --task speech_recognition_hybrid --seed 1 --user-dir espresso \ - --log-interval $((200/ngpus)) --log-format simple --num-workers 0 --max-tokens 120000 --max-sentences 128 \ + --log-interval $((200/ngpus)) --log-format simple --num-workers 0 --data-buffer-size 0 --max-tokens 120000 --max-sentences 128 \ --curriculum 1 --valid-subset $valid_subset --max-sentences-valid 128 --ddp-backend no_c10d \ --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ --max-epoch 26 --optimizer adam --lr 0.001 --weight-decay 0.0 --start-reduce-lr-epoch 11 \ diff --git a/examples/asr_wsj/run_xent.sh b/examples/asr_wsj/run_xent.sh index f5a8fc1c5..b8c8f10a7 100755 --- a/examples/asr_wsj/run_xent.sh +++ b/examples/asr_wsj/run_xent.sh @@ -165,7 +165,7 @@ if [ ${stage} -le 5 ]; then log_file=$dir/log/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data/xent --task speech_recognition_hybrid --seed 1 --user-dir espresso \ - --log-interval $((100/ngpus)) --log-format simple --num-workers 0 --max-tokens 160000 --max-sentences 256 \ + --log-interval $((100/ngpus)) --log-format simple --num-workers 0 --data-buffer-size 0 --max-tokens 160000 --max-sentences 256 \ --valid-subset $valid_subset --max-sentences-valid 256 --ddp-backend no_c10d \ --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ --max-epoch 40 --optimizer adam --lr 0.001 --weight-decay 0.0 \ From 9d3b8b3624893bd8ea4faa333b6c255120590ae9 Mon Sep 17 00:00:00 2001 From: freewym Date: Tue, 16 Jun 2020 20:11:09 -0400 Subject: [PATCH 084/119] remove useless max_{source,target}_positions arguments --- espresso/data/asr_chain_dataset.py | 9 +-------- espresso/data/asr_dataset.py | 7 ------- espresso/data/asr_xent_dataset.py | 9 +-------- espresso/tasks/speech_recognition.py | 5 ----- espresso/tasks/speech_recognition_hybrid.py | 7 ------- espresso/tools/specaug_interpolate.py | 4 ++-- 6 files changed, 4 insertions(+), 37 deletions(-) diff --git a/espresso/data/asr_chain_dataset.py b/espresso/data/asr_chain_dataset.py index 2a68ef7ed..62b8f4cc7 100644 --- a/espresso/data/asr_chain_dataset.py +++ b/espresso/data/asr_chain_dataset.py @@ -154,25 +154,18 @@ class AsrChainDataset(FairseqDataset): tgt (espresso.data.NumeratorGraphDataset, optional): target numerator graph dataset to wrap tgt_sizes (List[int], optional): target sizes (num of states in the numerator graph) text (torch.utils.data.Dataset, optional): text dataset to wrap - max_source_positions (int, optional): max number of frames in the - source (default: 1024). 
- max_target_positions (int, optional): max number of tokens in the target - sentence (default: 1024) shuffle (bool, optional): shuffle dataset elements before batching (default: True) """ def __init__( - self, src, src_sizes, tgt=None, tgt_sizes=None, text=None, - max_source_positions=1024, max_target_positions=1024, shuffle=True, + self, src, src_sizes, tgt=None, tgt_sizes=None, text=None, shuffle=True, ): self.src = src self.tgt = tgt self.src_sizes = np.array(src_sizes) self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None self.text = text - self.max_source_positions = max_source_positions - self.max_target_positions = max_target_positions self.shuffle = shuffle self.epoch = 1 num_before_matching = len(self.src.utt_ids) diff --git a/espresso/data/asr_dataset.py b/espresso/data/asr_dataset.py index 602fc7fa0..db530fd76 100644 --- a/espresso/data/asr_dataset.py +++ b/espresso/data/asr_dataset.py @@ -94,10 +94,6 @@ class AsrDataset(FairseqDataset): (default: True). left_pad_target (bool, optional): pad target tensors on the left side (default: False). - max_source_positions (int, optional): max number of frames in the - source (default: 1024). - max_target_positions (int, optional): max number of tokens in the target - sentence (default: 1024) shuffle (bool, optional): shuffle dataset elements before batching (default: True) input_feeding (bool, optional): create a shifted version of the targets @@ -108,7 +104,6 @@ def __init__( self, src, src_sizes, tgt=None, tgt_sizes=None, dictionary=None, left_pad_source=False, left_pad_target=False, - max_source_positions=1024, max_target_positions=1024, shuffle=True, input_feeding=True, ): self.src = src @@ -118,8 +113,6 @@ def __init__( self.dictionary = dictionary self.left_pad_source = left_pad_source self.left_pad_target = left_pad_target - self.max_source_positions = max_source_positions - self.max_target_positions = max_target_positions self.shuffle = shuffle self.input_feeding = input_feeding if self.tgt is not None: diff --git a/espresso/data/asr_xent_dataset.py b/espresso/data/asr_xent_dataset.py index 60d573daf..4afbaacb1 100644 --- a/espresso/data/asr_xent_dataset.py +++ b/espresso/data/asr_xent_dataset.py @@ -316,10 +316,6 @@ class AsrXentDataset(FairseqDataset): tgt_sizes (List[int], optional): target sizes (num of states in the numerator graph) tgt_vocab_size (int, optional): used for setting padding index text (torch.utils.data.Dataset, optional): text dataset to wrap - max_source_positions (int, optional): max number of frames in the - source (default: 1024). 
- max_target_positions (int, optional): max number of tokens in the target - sentence (default: 1024) shuffle (bool, optional): shuffle dataset elements before batching (default: True) seed (int, optional): random seed for generating a chunk from an utterance @@ -334,8 +330,7 @@ class AsrXentDataset(FairseqDataset): def __init__( self, src, src_sizes, tgt: Optional[AliScpCachedDataset] = None, tgt_sizes=None, text=None, - max_source_positions=1024, max_target_positions=1024, shuffle=True, - seed=1, chunk_width=None, chunk_left_context=None, chunk_right_context=None, + shuffle=True, seed=1, chunk_width=None, chunk_left_context=None, chunk_right_context=None, label_delay=0, random_chunking=True, ): self.src = src @@ -343,8 +338,6 @@ def __init__( self.src_sizes = np.array(src_sizes) self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None self.text = text - self.max_source_positions = max_source_positions - self.max_target_positions = max_target_positions self.shuffle = shuffle self.seed = seed self.epoch = 1 diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index f5bd5a845..0b301f424 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -30,7 +30,6 @@ def get_asr_dataset_from_json( data_path, split, tgt_dict, combine, upsample_primary, - max_source_positions, max_target_positions, seed=1, specaugment_config=None, ): """ @@ -113,8 +112,6 @@ def get_asr_dataset_from_json( tgt_dict, left_pad_source=False, left_pad_target=False, - max_source_positions=max_source_positions, - max_target_positions=max_target_positions, ) @@ -236,8 +233,6 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): data_path, split, self.tgt_dict, combine=combine, upsample_primary=self.args.upsample_primary, - max_source_positions=self.args.max_source_positions, - max_target_positions=self.args.max_target_positions, seed=self.args.seed, specaugment_config=self.specaugment_config, ) diff --git a/espresso/tasks/speech_recognition_hybrid.py b/espresso/tasks/speech_recognition_hybrid.py index dadd9404b..b17261682 100644 --- a/espresso/tasks/speech_recognition_hybrid.py +++ b/espresso/tasks/speech_recognition_hybrid.py @@ -38,7 +38,6 @@ def get_asr_dataset_from_json( data_path, split, dictionary, combine, upsample_primary, - max_source_positions, max_target_positions, lf_mmi=True, seed=1, specaugment_config=None, chunk_width=None, chunk_left_context=None, chunk_right_context=None, label_delay=0, @@ -146,16 +145,12 @@ def get_asr_dataset_from_json( src_dataset, src_dataset.sizes, tgt_dataset, tgt_dataset_sizes, text=text_dataset, - max_source_positions=max_source_positions, - max_target_positions=max_target_positions, ) else: return AsrXentDataset( src_dataset, src_dataset.sizes, tgt_dataset, tgt_dataset_sizes, text=text_dataset, - max_source_positions=max_source_positions, - max_target_positions=max_target_positions, seed=seed, chunk_width=chunk_width, chunk_left_context=chunk_left_context, chunk_right_context=chunk_right_context, label_delay=label_delay, random_chunking=(split == "train" and chunk_width is not None), @@ -318,8 +313,6 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): data_path, split, self.dictionary, combine=combine, upsample_primary=self.args.upsample_primary, - max_source_positions=self.args.max_source_positions, - max_target_positions=self.args.max_target_positions, lf_mmi=(self.args.criterion == "lattice_free_mmi"), seed=self.args.seed, 
specaugment_config=self.specaugment_config, chunk_width=None if self.training_stage and split in self.args.valid_subset.split(",") else self.chunk_width, diff --git a/espresso/tools/specaug_interpolate.py b/espresso/tools/specaug_interpolate.py index 2caf62ee5..d6afdcd90 100644 --- a/espresso/tools/specaug_interpolate.py +++ b/espresso/tools/specaug_interpolate.py @@ -31,7 +31,7 @@ def specaug(spec, W=80, F=27, T=70, num_freq_masks=2, num_time_masks=2, p=0.2, r T (int): maximum width of each time mask num_freq_masks (int): number of frequency masks num_time_masks (int): number of time masks - p (int): toal mask width shouldn't exeed this times num of frames + p (int): time mask width shouldn't exeed this times num of frames replace_with_zero (bool): if True, masked parts will be filled with 0, if False, filled with mean Returns: @@ -110,7 +110,7 @@ def time_mask(spec, T=40, num_masks=1, p=0.2, pad_value=0.): spec (torch.Tensor): input tensor of shape `(dim, T)` T (int): maximum width of each mask num_masks (int): number of masks - p (float): toal mask width shouldn't exeed this times num of frames + p (float): time mask width shouldn't exeed this times num of frames pad_value (float): value for padding Returns: From c8c1dfb6affe78199c2d9fcfdd0f533fafce600e Mon Sep 17 00:00:00 2001 From: freewym Date: Fri, 19 Jun 2020 01:34:20 -0400 Subject: [PATCH 085/119] code adaptation/changes according to the commits on Jun 18-23, 2020 --- espresso/models/speech_lstm.py | 130 +++++++++++++++----------- espresso/models/speech_tdnn.py | 18 +++- espresso/speech_train.py | 162 +++++++++++++++++---------------- 3 files changed, 175 insertions(+), 135 deletions(-) diff --git a/espresso/models/speech_lstm.py b/espresso/models/speech_lstm.py index acf926a13..e0332784d 100644 --- a/espresso/models/speech_lstm.py +++ b/espresso/models/speech_lstm.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. 
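A side note on the specaug_interpolate.py docstring fix above: the `p` argument caps each time mask at a proportion of the utterance length. A minimal sketch of that cap (not part of the patch, and simplified relative to the actual implementation):

import torch

def time_mask(spec, T=40, num_masks=1, p=0.2, pad_value=0.0):
    # spec: (dim, n_frames); each mask width is drawn from [0, T] but additionally
    # capped at p * n_frames, so short utterances never lose most of their frames
    cloned = spec.clone()
    n_frames = cloned.size(1)
    max_width = min(T, int(p * n_frames))
    for _ in range(num_masks):
        t = torch.randint(0, max_width + 1, (1,)).item()
        t0 = torch.randint(0, max(1, n_frames - t), (1,)).item()
        cloned[:, t0:t0 + t] = pad_value
    return cloned

masked = time_mask(torch.randn(80, 500), T=70, num_masks=2, p=0.2)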
import logging -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import torch from torch import Tensor @@ -433,15 +433,25 @@ def forward( ) def reorder_encoder_out(self, encoder_out: EncoderOut, new_order): - encoder_padding_mask = encoder_out.encoder_padding_mask.index_select(1, new_order) \ - if encoder_out.encoder_padding_mask is not None else None + encoder_padding_mask: Optional[Tensor] = encoder_out.encoder_padding_mask + src_lengths: Optional[Tensor] = encoder_out.src_lengths + new_encoder_padding_mask = ( + encoder_padding_mask + if encoder_padding_mask is None + else encoder_padding_mask.index_select(1, new_order) + ) + new_src_lengths = ( + src_lengths + if src_lengths is None + else src_lengths.index_select(0, new_order) + ) return EncoderOut( encoder_out=encoder_out.encoder_out.index_select(1, new_order), - encoder_padding_mask=encoder_padding_mask, + encoder_padding_mask=new_encoder_padding_mask, encoder_embedding=None, encoder_states=None, src_tokens=None, - src_lengths=encoder_out.src_lengths.index_select(0, new_order), + src_lengths=new_src_lengths, ) def max_positions(self): @@ -517,18 +527,6 @@ def __init__( self.scheduled_sampling_rate_scheduler = scheduled_sampling_rate_scheduler - def get_cached_state(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]): - cached_state = self.get_incremental_state(incremental_state, "cached_state") - assert cached_state is not None - prev_hiddens_ = cached_state["prev_hiddens"] - assert prev_hiddens_ is not None - prev_cells_ = cached_state["prev_cells"] - assert prev_cells_ is not None - prev_hiddens = [prev_hiddens_[i] for i in range(self.num_layers)] - prev_cells = [prev_cells_[j] for j in range(self.num_layers)] - input_feed = cached_state["input_feed"] # can be None for decoder-only language models - return prev_hiddens, prev_cells, input_feed - def forward( self, prev_output_tokens, @@ -691,7 +689,11 @@ def extract_features( prev_cells_tensor = torch.stack(prev_cells) cache_state = torch.jit.annotate( Dict[str, Optional[Tensor]], - {"prev_hiddens": prev_hiddens_tensor, "prev_cells": prev_cells_tensor, "input_feed": input_feed} + { + "prev_hiddens": prev_hiddens_tensor, + "prev_cells": prev_cells_tensor, + "input_feed": input_feed, + } ) self.set_incremental_state(incremental_state, "cached_state", cache_state) @@ -725,24 +727,40 @@ def output_layer(self, features, **kwargs): else: return features - def reorder_state(self, state: List[Tensor], new_order): - return [ - state_i.index_select(0, new_order) if state_i is not None else None - for state_i in state - ] + def get_cached_state( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + ) -> Tuple[List[Tensor], List[Tensor], Optional[Tensor]]: + cached_state = self.get_incremental_state(incremental_state, "cached_state") + assert cached_state is not None + prev_hiddens_ = cached_state["prev_hiddens"] + assert prev_hiddens_ is not None + prev_cells_ = cached_state["prev_cells"] + assert prev_cells_ is not None + prev_hiddens = [prev_hiddens_[i] for i in range(self.num_layers)] + prev_cells = [prev_cells_[j] for j in range(self.num_layers)] + input_feed = cached_state["input_feed"] # can be None for decoder-only language models + return prev_hiddens, prev_cells, input_feed - def reorder_incremental_state(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], new_order): - super().reorder_incremental_state(incremental_state, new_order) + def reorder_incremental_state( + self, + 
incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + new_order: Tensor, + ): if incremental_state is None or len(incremental_state) == 0: return prev_hiddens, prev_cells, input_feed = self.get_cached_state(incremental_state) - cached_state = (prev_hiddens, prev_cells, [input_feed]) - new_state = [self.reorder_state(state, new_order) for state in cached_state] - prev_hiddens_tensor = torch.stack(new_state[0]) - prev_cells_tensor = torch.stack(new_state[1]) + prev_hiddens = [p.index_select(0, new_order) for p in prev_hiddens] + prev_cells = [p.index_select(0, new_order) for p in prev_cells] + if input_feed is not None: + input_feed = input_feed.index_select(0, new_order) cached_state_new = torch.jit.annotate( Dict[str, Optional[Tensor]], - {"prev_hiddens": prev_hiddens_tensor, "prev_cells": prev_cells_tensor, "input_feed": new_state[2][0]} + { + "prev_hiddens": torch.stack(prev_hiddens), + "prev_cells": torch.stack(prev_cells), + "input_feed": input_feed, + } ) self.set_incremental_state(incremental_state, "cached_state", cached_state_new), return @@ -752,33 +770,37 @@ def masked_copy_incremental_state(self, incremental_state, another_cached_state, assert another_cached_state is None or len(another_cached_state) == 0 return prev_hiddens, prev_cells, input_feed = self.get_cached_state(incremental_state) - cached_state = (prev_hiddens, prev_cells, [input_feed]) - another_cached_state = (another_cached_state[0], another_cached_state[1], [another_cached_state[2]]) - - def mask_copy_state(state: List[Tensor], another_state: List[Tensor]): - new_state = [] - for state_i, another_state_i in zip(state, another_state): - if state_i is None: - assert another_state_i is None - new_state.append(None) - else: - assert state_i.size(0) == mask.size(0) and another_state_i is not None and \ - state_i.size() == another_state_i.size() - mask_unsqueezed = mask - for _ in range(1, len(state_i.size())): - mask_unsqueezed = mask_unsqueezed.unsqueeze(-1) - new_state.append(torch.where(mask_unsqueezed, state_i, another_state_i)) - return new_state - - new_state = [ - mask_copy_state(state, another_state) - for (state, another_state) in zip(cached_state, another_cached_state) + another_prev_hiddens, another_prev_cells, another_input_feed = \ + another_cached_state[0], another_cached_state[1], another_cached_state[2] + + def mask_copy_state(state: Optional[Tensor], another_state: Optional[Tensor]): + if state is None: + assert another_state is None + return None + else: + assert ( + state.size(0) == mask.size(0) and another_state is not None and + state.size() == another_state.size() + ) + mask_unsqueezed = mask + for _ in range(1, len(state.size())): + mask_unsqueezed = mask_unsqueezed.unsqueeze(-1) + return torch.where(mask_unsqueezed, state, another_state) + + prev_hiddens_new = [ + mask_copy_state(p, another_p) for (p, another_p) in zip(prev_hiddens, another_prev_hiddens) + ] + prev_cells_new = [ + mask_copy_state(p, another_p) for (p, another_p) in zip(prev_cells, another_prev_cells) ] - prev_hiddens_tensor = torch.stack(new_state[0]) - prev_cells_tensor = torch.stack(new_state[1]) + input_feed_new = mask_copy_state(input_feed, another_input_feed) cached_state_new = torch.jit.annotate( Dict[str, Optional[Tensor]], - {"prev_hiddens": prev_hiddens_tensor, "prev_cells": prev_cells_tensor, "input_feed": new_state[2][0]} + { + "prev_hiddens": torch.stack(prev_hiddens_new), + "prev_cells": torch.stack(prev_cells_new), + "input_feed": input_feed_new, + } ) self.set_incremental_state(incremental_state, 
"cached_state", cached_state_new) diff --git a/espresso/models/speech_tdnn.py b/espresso/models/speech_tdnn.py index 0412dd8e0..de470ca05 100644 --- a/espresso/models/speech_tdnn.py +++ b/espresso/models/speech_tdnn.py @@ -268,15 +268,25 @@ def output_layer(self, features, **kwargs): return self.fc_out(features) # T x B x C -> T x B x V def reorder_encoder_out(self, encoder_out: EncoderOut, new_order): - encoder_padding_mask = encoder_out.encoder_padding_mask.index_select(1, new_order) \ - if encoder_out.encoder_padding_mask is not None else None + encoder_padding_mask: Optional[Tensor] = encoder_out.encoder_padding_mask + src_lengths: Optional[Tensor] = encoder_out.src_lengths + new_encoder_padding_mask = ( + encoder_padding_mask + if encoder_padding_mask is None + else encoder_padding_mask.index_select(1, new_order) + ) + new_src_lengths = ( + src_lengths + if src_lengths is None + else src_lengths.index_select(0, new_order) + ) return EncoderOut( encoder_out=encoder_out.encoder_out.index_select(1, new_order), - encoder_padding_mask=encoder_padding_mask, + encoder_padding_mask=new_encoder_padding_mask, encoder_embedding=None, encoder_states=None, src_tokens=None, - src_lengths=encoder_out.src_lengths.index_select(0, new_order), + src_lengths=new_src_lengths, ) def max_positions(self): diff --git a/espresso/speech_train.py b/espresso/speech_train.py index 53edfcb89..7bdd3d3af 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -16,7 +16,6 @@ import numpy as np import torch - from fairseq import ( checkpoint_utils, distributed_utils, @@ -27,28 +26,29 @@ ) from fairseq.data import iterators from fairseq.logging import meters, metrics, progress_bar -from fairseq.trainer import Trainer from fairseq.model_parallel.megatron_trainer import MegatronTrainer +from fairseq.trainer import Trainer logging.basicConfig( - format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, stream=sys.stdout, ) -logger = logging.getLogger('espresso.speech_train') +logger = logging.getLogger("espresso.speech_rain") def main(args, init_distributed=False): utils.import_user_module(args) - assert args.max_tokens is not None or args.max_sentences is not None, \ - 'Must specify batch size either with --max-tokens or --max-sentences' + assert ( + args.max_tokens is not None or args.max_sentences is not None + ), "Must specify batch size either with --max-tokens or --max-sentences" metrics.reset() # Initialize CUDA and distributed training - if torch.cuda.is_available() and not args.cpu and not getattr(args, 'tpu', False): + if torch.cuda.is_available() and not args.cpu and not getattr(args, "tpu", False): torch.cuda.set_device(args.device_id) np.random.seed(args.seed) utils.set_torch_seed(args.seed) @@ -65,18 +65,22 @@ def main(args, init_distributed=False): task = tasks.setup_task(args) # Load valid dataset (we load training data below, based on the latest checkpoint) - for valid_sub_split in args.valid_subset.split(','): + for valid_sub_split in args.valid_subset.split(","): task.load_dataset(valid_sub_split, combine=False, epoch=1) # Build model and criterion model = task.build_model(args) criterion = task.build_criterion(args) logger.info(model) - logger.info('model {}, criterion {}'.format(args.arch, criterion.__class__.__name__)) - logger.info('num. model params: {} (num. 
trained: {})'.format( - sum(p.numel() for p in model.parameters()), - sum(p.numel() for p in model.parameters() if p.requires_grad), - )) + logger.info( + "model {}, criterion {}".format(args.arch, criterion.__class__.__name__) + ) + logger.info( + "num. model params: {} (num. trained: {})".format( + sum(p.numel() for p in model.parameters()), + sum(p.numel() for p in model.parameters() if p.requires_grad), + ) + ) # (optionally) Configure quantization if args.quantization_config_path is not None: @@ -94,18 +98,22 @@ def main(args, init_distributed=False): else: trainer = MegatronTrainer(args, task, model, criterion) - logger.info('training on {} devices (GPUs/TPUs)'.format(args.distributed_world_size)) - logger.info('max input frames per GPU = {} and max sentences per GPU = {}'.format( - args.max_tokens, - args.max_sentences, - )) + logger.info( + "training on {} devices (GPUs/TPUs)".format(args.distributed_world_size) + ) + logger.info( + "max input frames per GPU = {} and max sentences per GPU = {}".format( + args.max_tokens, args.max_sentences + ) + ) # Load the latest checkpoint if one is available and restore the # corresponding train iterator extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer) if args.tpu: import torch_xla.core.xla_model as xm - xm.rendezvous('load_checkpoint') # wait for all workers + + xm.rendezvous("load_checkpoint") # wait for all workers xm.mark_step() # Train until the learning rate gets too small @@ -113,10 +121,7 @@ def main(args, init_distributed=False): lr = trainer.get_lr() train_meter = meters.StopwatchMeter() train_meter.start() - while ( - lr > args.min_lr - and epoch_itr.next_epoch_idx <= max_epoch - ): + while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch: # train for one epoch valid_losses, should_stop = train(args, trainer, task, epoch_itr) if should_stop: @@ -128,10 +133,10 @@ def main(args, init_distributed=False): epoch_itr = trainer.get_train_iterator( epoch_itr.next_epoch_idx, # sharded data: get train iterator for next epoch - load_dataset=(os.pathsep in getattr(args, 'data', '')), + load_dataset=(os.pathsep in getattr(args, "data", "")), ) train_meter.stop() - logger.info('done training in {:.1f} seconds'.format(train_meter.sum)) + logger.info("done training in {:.1f} seconds".format(train_meter.sum)) def should_stop_early(args, valid_loss): @@ -144,7 +149,7 @@ def should_stop_early(args, valid_loss): def is_better(a, b): return a > b if args.maximize_best_checkpoint_metric else a < b - prev_best = getattr(should_stop_early, 'best', None) + prev_best = getattr(should_stop_early, "best", None) if prev_best is None or is_better(valid_loss, prev_best): should_stop_early.best = valid_loss should_stop_early.num_runs = 0 @@ -152,7 +157,11 @@ def is_better(a, b): else: should_stop_early.num_runs += 1 if should_stop_early.num_runs >= args.patience: - logger.info('early stop since valid performance hasn\'t improved for last {} runs'.format(args.patience)) + logger.info( + "early stop since valid performance hasn't improved for last {} runs".format( + args.patience + ) + ) return True else: return False @@ -161,17 +170,18 @@ def is_better(a, b): def tpu_data_loader(args, itr): import torch_xla.core.xla_model as xm import torch_xla.distributed.parallel_loader as pl - xm.rendezvous('tpu_data_loader') # wait for all workers + + xm.rendezvous("tpu_data_loader") # wait for all workers xm.mark_step() device = utils.get_tpu_device(args) return iterators.CountingIterator( pl.ParallelLoader(itr, 
[device]).per_device_loader(device), - start=getattr(itr, 'n', 0), + start=getattr(itr, "n", 0), total=len(itr), ) -@metrics.aggregate('train') +@metrics.aggregate("train") def train(args, trainer, task, epoch_itr): """Train the model for one epoch and return validation losses.""" # Initialize data iterator @@ -185,7 +195,7 @@ def train(args, trainer, task, epoch_itr): else args.update_freq[-1] ) itr = iterators.GroupedIterator(itr, update_freq) - if getattr(args, 'tpu', False): + if getattr(args, "tpu", False): itr = tpu_data_loader(args, itr) progress = progress_bar.progress_bar( itr, @@ -195,18 +205,18 @@ def train(args, trainer, task, epoch_itr): tensorboard_logdir=( args.tensorboard_logdir if distributed_utils.is_master(args) else None ), - default_log_format=('tqdm' if not args.no_progress_bar else 'simple'), + default_log_format=("tqdm" if not args.no_progress_bar else "simple"), ) trainer.begin_epoch(epoch_itr.epoch) - if hasattr(trainer.criterion, 'set_epoch'): + if hasattr(trainer.criterion, "set_epoch"): trainer.criterion.set_epoch(epoch_itr.epoch) - valid_subsets = args.valid_subset.split(',') + valid_subsets = args.valid_subset.split(",") should_stop = False - for samples in progress: - with metrics.aggregate('train_inner'): + for i, samples in enumerate(progress): + with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function("train_step-%d" % i): log_output = trainer.train_step(samples) if log_output is None: # OOM, overflow, ... continue @@ -214,17 +224,17 @@ def train(args, trainer, task, epoch_itr): # log mid-epoch stats num_updates = trainer.get_num_updates() if num_updates % args.log_interval == 0: - stats = get_training_stats(metrics.get_smoothed_values('train_inner')) - progress.log(stats, tag='train_inner', step=num_updates) + stats = get_training_stats(metrics.get_smoothed_values("train_inner")) + progress.log(stats, tag="train_inner", step=num_updates) # reset mid-epoch stats after each log interval # the end-of-epoch stats will still be preserved - metrics.reset_meters('train_inner') + metrics.reset_meters("train_inner") # update the state prior stored in the model for cross-entropy training - if hasattr(task, 'update_state_prior'): + if hasattr(task, "update_state_prior"): task.update_state_prior(trainer.get_model()) - + end_of_epoch = not itr.has_next() valid_losses, should_stop = validate_and_save( args, trainer, task, epoch_itr, valid_subsets, end_of_epoch @@ -233,31 +243,25 @@ def train(args, trainer, task, epoch_itr): break # log end-of-epoch stats - stats = get_training_stats(metrics.get_smoothed_values('train')) - progress.print(stats, tag='train', step=num_updates) + stats = get_training_stats(metrics.get_smoothed_values("train")) + progress.print(stats, tag="train", step=num_updates) # reset epoch-level meters - metrics.reset_meters('train') + metrics.reset_meters("train") return valid_losses, should_stop def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoch): num_updates = trainer.get_num_updates() do_save = ( - ( - args.save_interval_updates > 0 - and num_updates > 0 - and num_updates % args.save_interval_updates == 0 - ) - or (end_of_epoch and epoch_itr.epoch % args.save_interval == 0) - ) + args.save_interval_updates > 0 + and num_updates > 0 + and num_updates % args.save_interval_updates == 0 + ) or (end_of_epoch and epoch_itr.epoch % args.save_interval == 0) do_validate = ( - ( - (not end_of_epoch and do_save) # validate during mid-epoch saves - or (end_of_epoch and epoch_itr.epoch % 
args.validate_interval == 0) - ) - and not args.disable_validation - ) + (not end_of_epoch and do_save) # validate during mid-epoch saves + or (end_of_epoch and epoch_itr.epoch % args.validate_interval == 0) + ) and not args.disable_validation # Validate valid_losses = [None] @@ -279,7 +283,7 @@ def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoc def get_training_stats(stats): - stats['wall'] = round(metrics.get_meter('default', 'wall').elapsed_time, 0) + stats["wall"] = round(metrics.get_meter("default", "wall").elapsed_time, 0) return stats @@ -294,7 +298,7 @@ def validate(args, trainer, task, epoch_itr, subsets): for subset in subsets: # Initialize data iterator itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False) - if getattr(args, 'tpu', False): + if getattr(args, "tpu", False): itr = tpu_data_loader(args, itr) progress = progress_bar.progress_bar( itr, @@ -305,7 +309,7 @@ def validate(args, trainer, task, epoch_itr, subsets): tensorboard_logdir=( args.tensorboard_logdir if distributed_utils.is_master(args) else None ), - default_log_format=('tqdm' if not args.no_progress_bar else 'simple'), + default_log_format=("tqdm" if not args.no_progress_bar else "simple"), ) # create a new root metrics aggregator so validation metrics @@ -323,13 +327,12 @@ def validate(args, trainer, task, epoch_itr, subsets): def get_valid_stats(args, trainer, stats): - stats['num_updates'] = trainer.get_num_updates() - if hasattr(checkpoint_utils.save_checkpoint, 'best'): - key = 'best_{0}'.format(args.best_checkpoint_metric) + stats["num_updates"] = trainer.get_num_updates() + if hasattr(checkpoint_utils.save_checkpoint, "best"): + key = "best_{0}".format(args.best_checkpoint_metric) best_function = max if args.maximize_best_checkpoint_metric else min stats[key] = best_function( - checkpoint_utils.save_checkpoint.best, - stats[args.best_checkpoint_metric], + checkpoint_utils.save_checkpoint.best, stats[args.best_checkpoint_metric] ) return stats @@ -345,14 +348,22 @@ def print_options_meaning_changes(args): """Options that have different meanings than those in the translation task are explained here. 
""" - logger.info('--max-tokens is the maximum number of input frames in a batch') + logger.info("--max-tokens is the maximum number of input frames in a batch") def cli_main(modify_parser=None): parser = options.get_training_parser() args = options.parse_args_and_arch(parser, modify_parser=modify_parser) print_options_meaning_changes(args) + if args.profile: + with torch.cuda.profiler.profile(): + with torch.autograd.profiler.emit_nvtx(): + cli_main_helper(args) + else: + cli_main_helper(args) + +def cli_main_helper(args): if args.distributed_init_method is None: distributed_utils.infer_init_method(args) @@ -369,29 +380,26 @@ def cli_main(modify_parser=None): else: distributed_main(args.device_id, args) elif args.distributed_world_size > 1: - if not getattr(args, 'tpu', False): + if not getattr(args, "tpu", False): # fallback for single node with multiple GPUs assert args.distributed_world_size <= torch.cuda.device_count() port = random.randint(10000, 20000) - args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port) + args.distributed_init_method = "tcp://localhost:{port}".format(port=port) args.distributed_rank = None # set based on device id torch.multiprocessing.spawn( - fn=distributed_main, - args=(args, ), - nprocs=args.distributed_world_size, + fn=distributed_main, args=(args,), nprocs=args.distributed_world_size ) else: import torch_xla.distributed.xla_multiprocessing as xmp - torch.multiprocessing.set_sharing_strategy('file_system') + + torch.multiprocessing.set_sharing_strategy("file_system") xmp.spawn( - fn=distributed_main, - args=(args, ), - nprocs=8, # use all 8 TPU cores + fn=distributed_main, args=(args,), nprocs=8 # use all 8 TPU cores ) else: # single GPU training main(args) -if __name__ == '__main__': +if __name__ == "__main__": cli_main() From 027595bbfa0eae767a5d66ccb1d070b82e31c394 Mon Sep 17 00:00:00 2001 From: freewym Date: Thu, 25 Jun 2020 19:54:03 -0400 Subject: [PATCH 086/119] code adaptation/changes according to the commits on Jun 24-25, 2020; fix validation loss in LSTM models --- espresso/data/__init__.py | 23 ++-- .../data/asr_bucket_pad_length_dataset.py | 91 ++++++++++++++ espresso/data/asr_chain_dataset.py | 60 +++++++-- espresso/data/asr_dataset.py | 86 +++++++++++-- espresso/data/asr_xent_dataset.py | 116 ++++++++++++++---- espresso/dump_posteriors.py | 1 + espresso/models/speech_lstm.py | 13 +- espresso/models/speech_lstm_encoder_model.py | 7 +- espresso/speech_recognize.py | 1 + espresso/tasks/speech_recognition.py | 17 ++- espresso/tasks/speech_recognition_hybrid.py | 18 ++- .../tools/generate_log_probs_for_decoding.py | 1 - espresso/tools/simple_greedy_decoder.py | 15 ++- examples/asr_librispeech/run.sh | 6 +- examples/asr_swbd/run.sh | 6 +- examples/asr_wsj/run.sh | 12 +- fairseq/sequence_generator.py | 8 +- 17 files changed, 401 insertions(+), 80 deletions(-) create mode 100644 espresso/data/asr_bucket_pad_length_dataset.py diff --git a/espresso/data/__init__.py b/espresso/data/__init__.py index 1fadbbc76..152210569 100644 --- a/espresso/data/__init__.py +++ b/espresso/data/__init__.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from .asr_bucket_pad_length_dataset import FeatBucketPadLengthDataset, TextBucketPadLengthDataset from .asr_chain_dataset import AsrChainDataset, NumeratorGraphDataset from .asr_dataset import AsrDataset from .asr_dictionary import AsrDictionary @@ -15,14 +16,16 @@ ) __all__ = [ - 'AliScpCachedDataset', - 'AsrChainDataset', - 'AsrDataset', - 'AsrDictionary', - 'AsrTextDataset', - 'AsrXentDataset', - 'FeatScpCachedDataset', - 'FeatScpDataset', - 'FeatScpInMemoryDataset', - 'NumeratorGraphDataset', + "AliScpCachedDataset", + "AsrChainDataset", + "AsrDataset", + "AsrDictionary", + "AsrTextDataset", + "AsrXentDataset", + "FeatBucketPadLengthDataset", + "FeatScpCachedDataset", + "FeatScpDataset", + "FeatScpInMemoryDataset", + "NumeratorGraphDataset", + "TextBucketPadLengthDataset", ] diff --git a/espresso/data/asr_bucket_pad_length_dataset.py b/espresso/data/asr_bucket_pad_length_dataset.py new file mode 100644 index 000000000..63d16af55 --- /dev/null +++ b/espresso/data/asr_bucket_pad_length_dataset.py @@ -0,0 +1,91 @@ +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn.functional as F + +from fairseq.data import BucketPadLengthDataset + + +class FeatBucketPadLengthDataset(BucketPadLengthDataset): + """ + Bucket and pad item lengths to the nearest bucket size for float tensors (features) + of shape `(length, feat_dim)`. This can be used to + reduce the number of unique batch shapes, which is important on TPUs since + each new batch shape requires a recompilation. + + Args: + dataset (FairseqDatset): dataset to bucket + sizes (List[int]): all item sizes + num_buckets (int): number of buckets to create + pad_idx (float, optional): padding value + left_pad (bool, optional): if True, pad on the left; otherwise right pad + """ + + def __init__( + self, + dataset, + sizes, + num_buckets, + pad_idx=None, + left_pad=False, + ): + super().__init__(dataset, sizes, num_buckets, pad_idx, left_pad) + if self.pad_idx is None: + self.pad_value = 0.0 + else: + self.pad_value = pad_idx + self.utt_ids = self.dataset.utt_ids + + def __getitem__(self, index): + item = self.dataset[index] + bucket_size = self._bucketed_sizes[index] + num_pad = bucket_size - item.size(-1) + return F.pad( + item, + (0, 0, num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad), + value=self.pad_value, + ) + + +class TextBucketPadLengthDataset(BucketPadLengthDataset): + """ + Bucket and pad item lengths to the nearest bucket size for :class:`AsrTextDataset`. + The main difference of this class from :class:`BucketPadLengthDataset` is that + here we only bucket the first element in the returned tuple of + :func:`AsrTextDataset.__getitem__`. This can be used to + reduce the number of unique batch shapes, which is important on TPUs since + each new batch shape requires a recompilation. 
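To make the bucketing idea concrete, a minimal sketch (not part of the patch) of padding a single utterance's features up to a bucket boundary, assuming features of shape (length, feat_dim):

import torch
import torch.nn.functional as F

feat = torch.randn(97, 80)         # 97 frames of 80-dim features
bucket_size, pad_value = 100, 0.0  # nearest bucket boundary

# right-pad the time axis up to the bucket size so many nearby lengths share
# one batch shape, which avoids repeated recompilation on TPUs
num_pad = bucket_size - feat.size(0)
padded = F.pad(feat, (0, 0, 0, num_pad), value=pad_value)
print(padded.shape)  # torch.Size([100, 80])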
+ + Args: + dataset (FairseqDatset): dataset to bucket + sizes (List[int]): all item sizes + num_buckets (int): number of buckets to create + pad_idx (float, optional): padding value + left_pad (bool, optional): if True, pad on the left; otherwise right pad + """ + + def __init__( + self, + dataset, + sizes, + num_buckets, + pad_idx=None, + left_pad=False, + ): + super().__init__(dataset, sizes, num_buckets, pad_idx, left_pad) + self.utt_ids = self.dataset.utt_ids + + def __getitem__(self, index): + item = self.dataset[index][0] + bucket_size = self._bucketed_sizes[index] + num_pad = bucket_size - item.size(-1) + return ( + F.pad( + item, + (num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad), + value=self.pad_idx, + ), + self.dataset[index][1] + ) diff --git a/espresso/data/asr_chain_dataset.py b/espresso/data/asr_chain_dataset.py index 62b8f4cc7..1258982ca 100644 --- a/espresso/data/asr_chain_dataset.py +++ b/espresso/data/asr_chain_dataset.py @@ -16,10 +16,11 @@ import espresso.tools.utils as speech_utils + logger = logging.getLogger(__name__) -def collate(samples): +def collate(samples, src_bucketed=False): try: from pychain import ChainGraphBatch except ImportError: @@ -45,12 +46,17 @@ def merge(key): id = torch.LongTensor([s["id"] for s in samples]) src_frames = merge("source") # sort by descending source length - src_lengths = torch.IntTensor([s["source"].size(0) for s in samples]) + if src_bucketed: + src_lengths = torch.IntTensor([ + s["source"].ne(0.0).any(dim=1).int().sum() for s in samples + ]) + else: + src_lengths = torch.IntTensor([s["source"].size(0) for s in samples]) src_lengths, sort_order = src_lengths.sort(descending=True) id = id.index_select(0, sort_order) utt_id = [samples[i]["utt_id"] for i in sort_order.numpy()] src_frames = src_frames.index_select(0, sort_order) - ntokens = sum(s["source"].size(0) for s in samples) + ntokens = src_lengths.sum().item() target = None if samples[0].get("target", None) is not None: @@ -155,11 +161,14 @@ class AsrChainDataset(FairseqDataset): tgt_sizes (List[int], optional): target sizes (num of states in the numerator graph) text (torch.utils.data.Dataset, optional): text dataset to wrap shuffle (bool, optional): shuffle dataset elements before batching - (default: True) + (default: True). + num_buckets (int, optional): if set to a value greater than 0, then + batches will be bucketed into the given number of batch shapes. """ def __init__( self, src, src_sizes, tgt=None, tgt_sizes=None, text=None, shuffle=True, + num_buckets=0, ): self.src = src self.tgt = tgt @@ -183,6 +192,29 @@ def __init__( "{} remaining".format(num_removed, num_after_matching) ) + if num_buckets > 0: + from espresso.data import FeatBucketPadLengthDataset + self.src = FeatBucketPadLengthDataset( + self.src, + sizes=self.src_sizes, + num_buckets=num_buckets, + pad_idx=0.0, + left_pad=False, + ) + self.src_sizes = self.src.sizes + logger.info("bucketing source lengths: {}".format(list(self.src.buckets))) + + # determine bucket sizes using self.num_tokens, which will return + # the padded lengths (thanks to FeatBucketPadLengthDataset) + num_tokens = np.vectorize(self.num_tokens, otypes=[np.long]) + self.bucketed_num_tokens = num_tokens(np.arange(len(self.src))) + self.buckets = [ + (None, num_tokens) + for num_tokens in np.unique(self.bucketed_num_tokens) + ] + else: + self.buckets = None + def _match_src_tgt(self): """Makes utterances in src and tgt the same order in terms of their utt_ids. 
Removes those that are only present in one of them.""" @@ -229,6 +261,9 @@ def _match_src_text(self): assert self.src.utt_ids == self.text.utt_ids return True + def get_batch_shapes(self): + return self.buckets + def __getitem__(self, index): tgt_item = self.tgt[index] if self.tgt is not None else None text_item = self.text[index][1] if self.text is not None else None @@ -269,7 +304,7 @@ def collater(self, samples): numerator graphs - `text` (List[str]): list of original text """ - return collate(samples) + return collate(samples, src_bucketed=(self.buckets is not None)) def num_tokens(self, index): """Return the number of frames in a sample. This value is used to @@ -288,9 +323,18 @@ def ordered_indices(self): indices = np.random.permutation(len(self)) else: indices = np.arange(len(self)) - if self.tgt_sizes is not None: - indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")] - return indices[np.argsort(self.src_sizes[indices], kind="mergesort")] + if self.buckets is None: + # sort by target length, then source length + if self.tgt_sizes is not None: + indices = indices[ + np.argsort(self.tgt_sizes[indices], kind="mergesort") + ] + return indices[np.argsort(self.src_sizes[indices], kind="mergesort")] + else: + # sort by bucketed_num_tokens, which is padded_src_len + return indices[ + np.argsort(self.bucketed_num_tokens[indices], kind="mergesort") + ] @property def supports_prefetch(self): diff --git a/espresso/data/asr_dataset.py b/espresso/data/asr_dataset.py index db530fd76..549b2611e 100644 --- a/espresso/data/asr_dataset.py +++ b/espresso/data/asr_dataset.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import logging import numpy as np import torch @@ -11,9 +12,17 @@ import espresso.tools.utils as speech_utils +logger = logging.getLogger(__name__) + + def collate( - samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False, + samples, + pad_idx, + eos_idx, + left_pad_source=True, + left_pad_target=False, input_feeding=True, + src_bucketed=False, ): if len(samples) == 0: return {} @@ -34,7 +43,12 @@ def merge(key, left_pad, move_eos_to_beginning=False): id = torch.LongTensor([s['id'] for s in samples]) src_frames = merge('source', left_pad=left_pad_source) # sort by descending source length - src_lengths = torch.IntTensor([s['source'].size(0) for s in samples]) + if src_bucketed: + src_lengths = torch.IntTensor([ + s['source'].ne(0.0).any(dim=1).int().sum() for s in samples + ]) + else: + src_lengths = torch.IntTensor([s['source'].size(0) for s in samples]) src_lengths, sort_order = src_lengths.sort(descending=True) id = id.index_select(0, sort_order) utt_id = [samples[i]['utt_id'] for i in sort_order.numpy()] @@ -45,7 +59,7 @@ def merge(key, left_pad, move_eos_to_beginning=False): if samples[0].get('target', None) is not None: target = merge('target', left_pad=left_pad_target) target = target.index_select(0, sort_order) - ntokens = sum(len(s['target']) for s in samples) + ntokens = sum(s['target'].ne(pad_idx).int().sum().item() for s in samples) if input_feeding: # we create a shifted version of targets for feeding the @@ -57,7 +71,7 @@ def merge(key, left_pad, move_eos_to_beginning=False): ) prev_output_tokens = prev_output_tokens.index_select(0, sort_order) else: - ntokens = sum(s['source'].size(0) for s in samples) + ntokens = src_lengths.sum().item() target_raw_text = None if samples[0].get('target_raw_text', None) is not None: @@ -95,9 +109,11 @@ class 
AsrDataset(FairseqDataset): left_pad_target (bool, optional): pad target tensors on the left side (default: False). shuffle (bool, optional): shuffle dataset elements before batching - (default: True) + (default: True). input_feeding (bool, optional): create a shifted version of the targets to be passed into the model for teacher forcing (default: True). + num_buckets (int, optional): if set to a value greater than 0, then + batches will be bucketed into the given number of batch shapes. """ def __init__( @@ -105,6 +121,7 @@ def __init__( tgt=None, tgt_sizes=None, dictionary=None, left_pad_source=False, left_pad_target=False, shuffle=True, input_feeding=True, + num_buckets=0, ): self.src = src self.tgt = tgt @@ -118,6 +135,39 @@ def __init__( if self.tgt is not None: self._match_src_tgt() + if num_buckets > 0: + from espresso.data import FeatBucketPadLengthDataset, TextBucketPadLengthDataset + self.src = FeatBucketPadLengthDataset( + self.src, + sizes=self.src_sizes, + num_buckets=num_buckets, + pad_idx=0.0, + left_pad=False, + ) + self.src_sizes = self.src.sizes + logger.info('bucketing source lengths: {}'.format(list(self.src.buckets))) + if self.tgt is not None: + self.tgt = TextBucketPadLengthDataset( + self.tgt, + sizes=self.tgt_sizes, + num_buckets=num_buckets, + pad_idx=self.dictionary.pad(), + left_pad=False, + ) + self.tgt_sizes = self.tgt.sizes + logger.info('bucketing target lengths: {}'.format(list(self.tgt.buckets))) + + # determine bucket sizes using self.num_tokens, which will return + # the padded lengths (thanks to FeatBucketPadLengthDataset) + num_tokens = np.vectorize(self.num_tokens, otypes=[np.long]) + self.bucketed_num_tokens = num_tokens(np.arange(len(self.src))) + self.buckets = [ + (None, num_tokens) + for num_tokens in np.unique(self.bucketed_num_tokens) + ] + else: + self.buckets = None + def _match_src_tgt(self): """Makes utterances in src and tgt the same order in terms of their utt_ids. 
Removes those that are only present in one of them.""" @@ -141,6 +191,9 @@ def _match_src_tgt(self): self.tgt_sizes = np.array(self.tgt.sizes) assert self.src.utt_ids == self.tgt.utt_ids + def get_batch_shapes(self): + return self.buckets + def __getitem__(self, index): tgt_item = self.tgt[index][0] if self.tgt is not None else None raw_text_item = self.tgt[index][1] if self.tgt is not None else None @@ -190,9 +243,13 @@ def collater(self, samples): - `target_raw_text` (List[str]): list of original text """ return collate( - samples, pad_idx=self.dictionary.pad(), eos_idx=self.dictionary.eos(), - left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target, + samples, + pad_idx=self.dictionary.pad(), + eos_idx=self.dictionary.eos(), + left_pad_source=self.left_pad_source, + left_pad_target=self.left_pad_target, input_feeding=self.input_feeding, + src_bucketed=(self.buckets is not None), ) def num_tokens(self, index): @@ -212,9 +269,18 @@ def ordered_indices(self): indices = np.random.permutation(len(self)) else: indices = np.arange(len(self)) - if self.tgt_sizes is not None: - indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')] - return indices[np.argsort(self.src_sizes[indices], kind='mergesort')] + if self.buckets is None: + # sort by target length, then source length + if self.tgt_sizes is not None: + indices = indices[ + np.argsort(self.tgt_sizes[indices], kind='mergesort') + ] + return indices[np.argsort(self.src_sizes[indices], kind='mergesort')] + else: + # sort by bucketed_num_tokens, which is padded_src_len + return indices[ + np.argsort(self.bucketed_num_tokens[indices], kind='mergesort') + ] @property def supports_prefetch(self): diff --git a/espresso/data/asr_xent_dataset.py b/espresso/data/asr_xent_dataset.py index 4afbaacb1..3b3057484 100644 --- a/espresso/data/asr_xent_dataset.py +++ b/espresso/data/asr_xent_dataset.py @@ -26,8 +26,16 @@ def collate( - samples, pad_idx, chunk_width, chunk_left_context, chunk_right_context, label_delay, - seed, epoch, random_chunking=True, + samples, + pad_idx, + chunk_width, + chunk_left_context, + chunk_right_context, + label_delay, + seed, + epoch, + src_bucketed=False, + random_chunking=True, ): if len(samples) == 0: return {} @@ -94,7 +102,12 @@ def chunking(src_item, tgt_item, tgt_start): else: s["source"] = src_item[: label_delay] - src_lengths = torch.IntTensor([s["source"].size(0) for s in samples]) + if src_bucketed: + src_lengths = torch.IntTensor([ + s["source"].ne(0.0).any(dim=1).int().sum() for s in samples + ]) + else: + src_lengths = torch.IntTensor([s["source"].size(0) for s in samples]) id = torch.LongTensor([s["id"] for s in samples]) utt_id = [s["utt_id"] for s in samples] src_frames = merge("source") @@ -102,9 +115,9 @@ def chunking(src_item, tgt_item, tgt_start): target = None if samples[0].get("target", None) is not None: target = merge("target") - ntokens = sum(len(s["target"]) for s in samples) + ntokens = sum(s["target"].ne(pad_idx).int().sum().item() for s in samples) else: - ntokens = sum(s["source"].size(0) for s in samples) + ntokens = src_lengths.sum().item() text = None if samples[0].get("text", None) is not None: @@ -135,7 +148,12 @@ def chunking(src_item, tgt_item, tgt_start): } return batch else: # sequential chunking, usually for chunk-wise test data - src_lengths = torch.IntTensor([s["source"].size(0) for s in samples]) + if src_bucketed: + src_lengths = torch.IntTensor([ + s["source"].ne(0.0).any(dim=1).int().sum() for s in samples + ]) + else: + src_lengths = 
torch.IntTensor([s["source"].size(0) for s in samples])
         id = torch.LongTensor([s["id"] for s in samples])
         utt_id = [s["utt_id"] for s in samples]
         ori_source = [s["source"] for s in samples]
@@ -165,7 +183,7 @@ def chunking(src_item, tgt_item, tgt_start):
             target = merge("target")
             ntokens = sum(s["target"].ne(pad_idx).int().sum().item() for s in samples)
         else:
-            ntokens = sum(s["source"].size(0) for s in samples)
+            ntokens = src_lengths.sum().item()
 
         batch = {
             "id": id,
@@ -317,21 +335,23 @@ class AsrXentDataset(FairseqDataset):
         tgt_vocab_size (int, optional): used for setting padding index
         text (torch.utils.data.Dataset, optional): text dataset to wrap
         shuffle (bool, optional): shuffle dataset elements before batching
-            (default: True)
-        seed (int, optional): random seed for generating a chunk from an utterance
-        chunk_width (int, optional): chunk width for chunk-wise training
-        chunk_left_context (int, optional): number of frames appended to the left of a chunk
-        chunk_right_context (int, optional): number of frames appended to the right of a chunk
+            (default: True).
+        num_buckets (int, optional): if set to a value greater than 0, then
+            batches will be bucketed into the given number of batch shapes.
+        seed (int, optional): random seed for generating a chunk from an utterance.
+        chunk_width (int, optional): chunk width for chunk-wise training.
+        chunk_left_context (int, optional): number of frames appended to the left of a chunk.
+        chunk_right_context (int, optional): number of frames appended to the right of a chunk.
        label_delay (int, optional): offset of the alignments as prediction labels. Can be
-            useful in archs such as asymmetric convolution, unidirectional LSTM, etc
+            useful in archs such as asymmetric convolution, unidirectional LSTM, etc.
        random_chunking (bool, optional): whether to do random chunking from an utterance, or sequentially
-            obtain chunks within each utterance. True for train and False for valid/test data
+            obtain chunks within each utterance. True for train and False for valid/test data.
     """
 
     def __init__(
        self, src, src_sizes, tgt: Optional[AliScpCachedDataset] = None, tgt_sizes=None, text=None,
-        shuffle=True, seed=1, chunk_width=None, chunk_left_context=None, chunk_right_context=None,
-        label_delay=0, random_chunking=True,
+        shuffle=True, num_buckets=0, seed=1, chunk_width=None,
+        chunk_left_context=None, chunk_right_context=None, label_delay=0, random_chunking=True,
     ):
         self.src = src
         self.tgt = tgt
@@ -373,6 +393,40 @@ def __init__(
             self.text.filter_and_reorder(indices)
             logger.warning("Done removal. 
{} examples remaining".format(len(indices))) + if num_buckets > 0: + from fairseq.data import BucketPadLengthDataset + from espresso.data import FeatBucketPadLengthDataset + self.src = FeatBucketPadLengthDataset( + self.src, + sizes=self.src_sizes, + num_buckets=num_buckets, + pad_idx=0.0, + left_pad=False, + ) + self.src_sizes = self.src.sizes + logger.info("bucketing source lengths: {}".format(list(self.src.buckets))) + if self.tgt is not None: + self.tgt = BucketPadLengthDataset( + self.tgt, + sizes=self.tgt_sizes, + num_buckets=num_buckets, + pad_idx=self.dictionary.pad(), + left_pad=False, + ) + self.tgt_sizes = self.tgt.sizes + logger.info("bucketing target lengths: {}".format(list(self.tgt.buckets))) + + # determine bucket sizes using self.num_tokens, which will return + # the padded lengths (thanks to FeatBucketPadLengthDataset) + num_tokens = np.vectorize(self.num_tokens, otypes=[np.long]) + self.bucketed_num_tokens = num_tokens(np.arange(len(self.src))) + self.buckets = [ + (None, num_tokens) + for num_tokens in np.unique(self.bucketed_num_tokens) + ] + else: + self.buckets = None + def _match_src_tgt(self): """Makes utterances in src and tgt the same order in terms of their utt_ids. Removes those that are only present in one of them.""" @@ -421,6 +475,9 @@ def _match_src_text(self): assert self.src.utt_ids == self.text.utt_ids return True + def get_batch_shapes(self): + return self.buckets + def __getitem__(self, index): tgt_item = self.tgt[index] if self.tgt is not None else None text_item = self.text[index][1] if self.text is not None else None @@ -463,9 +520,15 @@ def collater(self, samples): """ # pad_idx=-100 matches the default in criterions return collate( - samples, pad_idx=-100, chunk_width=self.chunk_width, - chunk_left_context=self.chunk_left_context, chunk_right_context=self.chunk_right_context, - label_delay=self.label_delay, seed=self.seed, epoch=self.epoch, + samples, + pad_idx=-100, + chunk_width=self.chunk_width, + chunk_left_context=self.chunk_left_context, + chunk_right_context=self.chunk_right_context, + label_delay=self.label_delay, + seed=self.seed, + epoch=self.epoch, + src_bucketed=(self.buckets is not None), random_chunking=self.random_chunking, ) @@ -488,9 +551,18 @@ def ordered_indices(self): indices = np.random.permutation(len(self)) else: indices = np.arange(len(self)) - if self.tgt_sizes is not None: - indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")] - return indices[np.argsort(self.src_sizes[indices], kind="mergesort")] + if self.buckets is None: + # sort by target length, then source length + if self.tgt_sizes is not None: + indices = indices[ + np.argsort(self.tgt_sizes[indices], kind="mergesort") + ] + return indices[np.argsort(self.src_sizes[indices], kind="mergesort")] + else: + # sort by bucketed_num_tokens, which is padded_src_len + return indices[ + np.argsort(self.bucketed_num_tokens[indices], kind="mergesort") + ] @property def supports_prefetch(self): diff --git a/espresso/dump_posteriors.py b/espresso/dump_posteriors.py index a3ee532ac..9c609488f 100755 --- a/espresso/dump_posteriors.py +++ b/espresso/dump_posteriors.py @@ -57,6 +57,7 @@ def _main(args, output_file): utils.split_paths(args.path), arg_overrides=eval(args.model_overrides), task=task, + suffix=getattr(args, "checkpoint_suffix", ""), ) # Load state prior for cross-entropy trained systems decoding diff --git a/espresso/models/speech_lstm.py b/espresso/models/speech_lstm.py index e0332784d..7d9fec93d 100644 --- a/espresso/models/speech_lstm.py +++ 
b/espresso/models/speech_lstm.py @@ -201,6 +201,7 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): dropout_out=args.encoder_rnn_dropout_out, bidirectional=args.encoder_rnn_bidirectional, residual=args.encoder_rnn_residual, + src_bucketed=(getattr(task.args, "num_batch_buckets", 0) > 0), max_source_positions=max_source_positions, ) decoder = SpeechLSTMDecoder( @@ -328,7 +329,7 @@ class SpeechLSTMEncoder(FairseqEncoder): def __init__( self, conv_layers_before=None, input_size=83, hidden_size=512, num_layers=1, dropout_in=0.1, dropout_out=0.1, bidirectional=False, - residual=False, left_pad=False, padding_value=0., + residual=False, left_pad=False, padding_value=0., src_bucketed=False, max_source_positions=DEFAULT_MAX_SOURCE_POSITIONS, ): super().__init__(None) # no src dictionary @@ -351,6 +352,7 @@ def __init__( ]) self.left_pad = left_pad self.padding_value = padding_value + self.src_bucketed = src_bucketed self.output_units = hidden_size if bidirectional: @@ -408,7 +410,12 @@ def forward( prev_x = x # pack embedded source tokens into a PackedSequence packed_x = nn.utils.rnn.pack_padded_sequence( - x, src_lengths.data, enforce_sorted=enforce_sorted + x, + ( + src_lengths.data if not self.src_bucketed else + src_lengths.new_full(src_lengths.size(), x.size(0)) + ), + enforce_sorted=enforce_sorted ) # apply LSTM @@ -548,7 +555,7 @@ def forward( - the decoder's output of shape `(batch, tgt_len, vocab)` - attention weights of shape `(batch, tgt_len, src_len)` """ - if self.scheduled_sampling_rate_scheduler is not None: + if self.training and self.scheduled_sampling_rate_scheduler is not None: epoch = kwargs.get("epoch", 1) sampling_prob = self.scheduled_sampling_rate_scheduler.step(epoch) if sampling_prob < 1.0: # apply scheduled sampling diff --git a/espresso/models/speech_lstm_encoder_model.py b/espresso/models/speech_lstm_encoder_model.py index bb201d3f2..5c661aa24 100644 --- a/espresso/models/speech_lstm_encoder_model.py +++ b/espresso/models/speech_lstm_encoder_model.py @@ -106,6 +106,7 @@ def build_model(cls, args, task): dropout_out=args.encoder_rnn_dropout_out, bidirectional=args.encoder_rnn_bidirectional, residual=args.encoder_rnn_residual, + src_bucketed=(getattr(task.args, "num_batch_buckets", 0) > 0), num_targets=getattr(task, "num_targets", None), # targets for encoder-only model chunk_width=getattr(task, "chunk_width", None), chunk_left_context=getattr(task, "chunk_left_context", 0), @@ -153,7 +154,7 @@ class SpeechChunkLSTMEncoder(SpeechLSTMEncoder): def __init__( self, conv_layers_before=None, input_size=83, hidden_size=512, num_layers=1, dropout_in=0.1, dropout_out=0.1, bidirectional=False, - residual=False, left_pad=False, padding_value=0., + residual=False, left_pad=False, padding_value=0., src_bucketed=False, num_targets=None, chunk_width=20, chunk_left_context=0, training_stage=True, max_source_positions=DEFAULT_MAX_SOURCE_POSITIONS, ): @@ -161,7 +162,7 @@ def __init__( conv_layers_before=conv_layers_before, input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, dropout_in=dropout_in, dropout_out=dropout_in, bidirectional=bidirectional, residual=residual, left_pad=left_pad, - padding_value=padding_value, max_source_positions=max_source_positions, + padding_value=padding_value, src_bucketed=src_bucketed, max_source_positions=max_source_positions, ) receptive_field_radius = sum(conv.padding[0] for conv in conv_layers_before.convolutions) \ if conv_layers_before is not None else 0 @@ -242,5 +243,5 @@ def base_architecture(args): 
@register_model_architecture("speech_lstm_encoder_model", "speech_conv_lstm_encoder_model_wsj") -def encoder_conv_lstm_wsj(args): +def conv_lstm_encoder_wsj(args): base_architecture(args) diff --git a/espresso/speech_recognize.py b/espresso/speech_recognize.py index ff8fbbcd8..1ea9240df 100755 --- a/espresso/speech_recognize.py +++ b/espresso/speech_recognize.py @@ -74,6 +74,7 @@ def _main(args, output_file): utils.split_paths(args.path), arg_overrides=eval(args.model_overrides), task=task, + suffix=getattr(args, "checkpoint_suffix", ""), ) for i, m in enumerate(models): if hasattr(m, 'is_wordlm') and m.is_wordlm: diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index 0b301f424..f2c918e32 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -12,7 +12,7 @@ import torch from fairseq import search, utils -from fairseq.data import ConcatDataset +from fairseq.data import BaseWrapperDataset, ConcatDataset from fairseq.logging import metrics from fairseq.tasks import FairseqTask, register_task @@ -30,6 +30,7 @@ def get_asr_dataset_from_json( data_path, split, tgt_dict, combine, upsample_primary, + num_buckets=0, seed=1, specaugment_config=None, ): """ @@ -112,6 +113,7 @@ def get_asr_dataset_from_json( tgt_dict, left_pad_source=False, left_pad_target=False, + num_buckets=num_buckets, ) @@ -159,6 +161,10 @@ def add_args(parser): help="max number of tokens in the target sequence") parser.add_argument("--upsample-primary", default=1, type=int, help="amount to upsample primary dataset") + parser.add_argument("--num-batch-buckets", default=0, type=int, metavar="N", + help="if >0, then bucket source and target lengths into N " + "buckets and pad accordingly; this is useful on TPUs " + "to minimize the number of compilations") parser.add_argument("--feat-in-channels", default=1, type=int, metavar="N", help="feature input channels") parser.add_argument("--specaugment-config", default=None, type=str, metavar="EXPR", @@ -233,13 +239,18 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): data_path, split, self.tgt_dict, combine=combine, upsample_primary=self.args.upsample_primary, + num_buckets=self.args.num_batch_buckets, seed=self.args.seed, specaugment_config=self.specaugment_config, ) src_dataset = self.datasets[split].src - self.feat_dim = src_dataset.feat_dim if not isinstance(src_dataset, ConcatDataset) \ - else src_dataset.datasets[0].feat_dim + if isinstance(src_dataset, ConcatDataset): + self.feat_dim = src_dataset.datasets[0].feat_dim + elif isinstance(src_dataset, BaseWrapperDataset): + self.feat_dim = src_dataset.dataset.feat_dim + else: + self.feat_dim = src_dataset.feat_dim # update the counts of and in tgt_dict with training data if split == "train": diff --git a/espresso/tasks/speech_recognition_hybrid.py b/espresso/tasks/speech_recognition_hybrid.py index b17261682..68011d258 100644 --- a/espresso/tasks/speech_recognition_hybrid.py +++ b/espresso/tasks/speech_recognition_hybrid.py @@ -12,7 +12,7 @@ import torch from fairseq import utils -from fairseq.data import ConcatDataset +from fairseq.data import BaseWrapperDataset, ConcatDataset from fairseq.tasks import FairseqTask, register_task @@ -38,6 +38,7 @@ def get_asr_dataset_from_json( data_path, split, dictionary, combine, upsample_primary, + num_buckets=0, lf_mmi=True, seed=1, specaugment_config=None, chunk_width=None, chunk_left_context=None, chunk_right_context=None, label_delay=0, @@ -145,12 +146,14 @@ def get_asr_dataset_from_json( 
src_dataset, src_dataset.sizes, tgt_dataset, tgt_dataset_sizes, text=text_dataset, + num_buckets=num_buckets, ) else: return AsrXentDataset( src_dataset, src_dataset.sizes, tgt_dataset, tgt_dataset_sizes, text=text_dataset, + num_buckets=num_buckets, seed=seed, chunk_width=chunk_width, chunk_left_context=chunk_left_context, chunk_right_context=chunk_right_context, label_delay=label_delay, random_chunking=(split == "train" and chunk_width is not None), @@ -202,6 +205,10 @@ def add_args(parser): help="max number of tokens in the target sequence") parser.add_argument("--upsample-primary", default=1, type=int, help="amount to upsample primary dataset") + parser.add_argument("--num-batch-buckets", default=0, type=int, metavar="N", + help="if >0, then bucket source and target lengths into N " + "buckets and pad accordingly; this is useful on TPUs " + "to minimize the number of compilations") parser.add_argument("--feat-in-channels", default=1, type=int, metavar="N", help="feature input channels") parser.add_argument("--specaugment-config", default=None, type=str, metavar="EXPR", @@ -313,6 +320,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): data_path, split, self.dictionary, combine=combine, upsample_primary=self.args.upsample_primary, + num_buckets=self.args.num_batch_buckets, lf_mmi=(self.args.criterion == "lattice_free_mmi"), seed=self.args.seed, specaugment_config=self.specaugment_config, chunk_width=None if self.training_stage and split in self.args.valid_subset.split(",") else self.chunk_width, @@ -321,8 +329,12 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): ) src_dataset = self.datasets[split].src - self.feat_dim = src_dataset.feat_dim if not isinstance(src_dataset, ConcatDataset) \ - else src_dataset.datasets[0].feat_dim + if isinstance(src_dataset, ConcatDataset): + self.feat_dim = src_dataset.datasets[0].feat_dim + elif isinstance(src_dataset, BaseWrapperDataset): + self.feat_dim = src_dataset.dataset.feat_dim + else: + self.feat_dim = src_dataset.feat_dim def build_generator(self, models, args): if args.score_reference: diff --git a/espresso/tools/generate_log_probs_for_decoding.py b/espresso/tools/generate_log_probs_for_decoding.py index 1b1a44dca..fd6022aa2 100644 --- a/espresso/tools/generate_log_probs_for_decoding.py +++ b/espresso/tools/generate_log_probs_for_decoding.py @@ -48,7 +48,6 @@ def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs): models (List[~fairseq.models.FairseqModel]): ensemble of models sample (dict): batch """ - self.model.reset_incremental_state() return self._generate(sample, **kwargs) def _generate(self, sample: Dict[str, Dict[str, Tensor]], **kwargs): diff --git a/espresso/tools/simple_greedy_decoder.py b/espresso/tools/simple_greedy_decoder.py index 3602bc348..5cdb9815b 100644 --- a/espresso/tools/simple_greedy_decoder.py +++ b/espresso/tools/simple_greedy_decoder.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Dict, Optional +from typing import Dict, List, Optional import numpy as np @@ -68,11 +68,17 @@ def decode(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs): bos_token (int, optional): beginning of sentence token (default: self.eos) """ - self.model.reset_incremental_state() return self._decode(sample, **kwargs) @torch.no_grad() def _decode(self, sample: Dict[str, Dict[str, Tensor]], bos_token: Optional[int] = None): + incremental_states = torch.jit.annotate( + List[Dict[str, Dict[str, Optional[Tensor]]]], + [ + torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) + for i in range(self.model.models_size) + ], + ) net_input = sample["net_input"] src_tokens = net_input["src_tokens"] input_size = src_tokens.size() @@ -111,7 +117,10 @@ def _decode(self, sample: Dict[str, Dict[str, Tensor]], bos_token: Optional[int] attn = attn[:, :, :step + 1] break log_probs, avg_attn_scores = self.model.forward_decoder( - tokens[:, :step + 1], encoder_outs, temperature=self.temperature, + tokens[:, : step + 1], + encoder_outs, + incremental_states, + temperature=self.temperature, ) tokens[:, step + 1] = log_probs.argmax(-1) if step > 0: # deal with finished predictions diff --git a/examples/asr_librispeech/run.sh b/examples/asr_librispeech/run.sh index b55e71fe4..37d5e789f 100755 --- a/examples/asr_librispeech/run.sh +++ b/examples/asr_librispeech/run.sh @@ -148,7 +148,7 @@ if [ ${stage} -le 4 ]; then for dataset in $test_set; do test_paths="$test_paths $lmdatadir/$dataset.tokens"; done test_paths=$(echo $test_paths | awk '{$1=$1;print}' | tr ' ' ',') ${decode_cmd} $lmdatadir/log/preprocess.log \ - python3 ../../preprocess.py --user-dir espresso --task language_modeling_for_asr \ + python3 ../../fairseq_cli/preprocess.py --user-dir espresso --task language_modeling_for_asr \ --workers 50 --srcdict $lmdict --only-source \ --trainpref $lmdatadir/train.tokens \ --validpref $lmdatadir/$valid_set.tokens \ @@ -168,7 +168,7 @@ if [ ${stage} -le 5 ]; then mkdir -p $lmdir/log log_file=$lmdir/log/train.log [ -f $lmdir/checkpoint_last.pt ] && log_file="-a $log_file" - CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../train.py $lmdatadir --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../fairseq_cli/train.py $lmdatadir --seed 1 --user-dir espresso \ --task language_modeling_for_asr --dict $lmdict \ --log-interval $((16000/ngpus)) --log-format simple \ --num-workers 0 --max-tokens 32000 --max-sentences 1024 --curriculum 1 \ @@ -189,7 +189,7 @@ if [ ${stage} -le 6 ]; then test_set_array=($test_set) for i in $(seq 0 $num); do log_file=$lmdir/log/evaluation_${test_set_array[$i]}.log - python3 ../../eval_lm.py $lmdatadir --user-dir espresso --cpu \ + python3 ../../fairseq_cli/eval_lm.py $lmdatadir --user-dir espresso --cpu \ --task language_modeling_for_asr --dict $lmdict --gen-subset ${gen_set_array[$i]} \ --max-tokens 40960 --max-sentences 1536 --sample-break-mode eos \ --path $lmdir/$lm_checkpoint 2>&1 | tee $log_file diff --git a/examples/asr_swbd/run.sh b/examples/asr_swbd/run.sh index 485a67ffd..e477abac4 100755 --- a/examples/asr_swbd/run.sh +++ b/examples/asr_swbd/run.sh @@ -186,7 +186,7 @@ if [ $stage -le 3 ]; then test_paths= && for dataset in $test_set; do test_paths="$test_paths $lmdatadir/$dataset.tokens"; done test_paths=$(echo $test_paths | awk '{$1=$1;print}' | tr ' ' ',') ${decode_cmd} $lmdatadir/log/preprocess.log \ - python3 ../../preprocess.py --user-dir espresso --task language_modeling_for_asr \ + python3 ../../fairseq_cli/preprocess.py 
--user-dir espresso --task language_modeling_for_asr \ --workers 50 --srcdict $lmdict --only-source \ --trainpref $lmdatadir/train.tokens \ --validpref $lmdatadir/$valid_set.tokens \ @@ -206,7 +206,7 @@ if [ $stage -le 4 ]; then mkdir -p $lmdir/log log_file=$lmdir/log/train.log [ -f $lmdir/checkpoint_last.pt ] && log_file="-a $log_file" - CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../train.py $lmdatadir --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../fairseq_cli/train.py $lmdatadir --seed 1 --user-dir espresso \ --task language_modeling_for_asr --dict $lmdict \ --log-interval $((1000/ngpus)) --log-format simple \ --num-workers 0 --max-tokens 25600 --max-sentences 1024 \ @@ -227,7 +227,7 @@ if [ $stage -le 5 ]; then test_set_array=($test_set) for i in $(seq 0 $num); do log_file=$lmdir/log/evaluation_${test_set_array[$i]}.log - python3 ../../eval_lm.py $lmdatadir --user-dir espresso --cpu \ + python3 ../../fairseq_cli/eval_lm.py $lmdatadir --user-dir espresso --cpu \ --task language_modeling_for_asr --dict $lmdict --gen-subset ${gen_set_array[$i]} \ --max-tokens 40960 --max-sentences 1536 --sample-break-mode eos \ --path $lmdir/$lm_checkpoint 2>&1 | tee $log_file diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index d2f2925ee..9d7931aba 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -153,7 +153,7 @@ if [ ${stage} -le 3 ]; then echo "$0: binarizing char text..." mkdir -p $lmdatadir/log ${decode_cmd} $lmdatadir/log/preprocess.log \ - python3 ../../preprocess.py --user-dir espresso --task language_modeling_for_asr \ + python3 ../../fairseq_cli/preprocess.py --user-dir espresso --task language_modeling_for_asr \ --workers 30 --srcdict $lmdict --only-source \ --trainpref $lmdatadir/train.tokens \ --validpref $lmdatadir/$valid_set.tokens \ @@ -163,7 +163,7 @@ if [ ${stage} -le 3 ]; then echo "$0: binarizing word text..." mkdir -p $wordlmdatadir/log ${decode_cmd} $wordlmdatadir/log/preprocess.log \ - python3 ../../preprocess.py --user-dir espresso --task language_modeling_for_asr \ + python3 ../../fairseq_cli/preprocess.py --user-dir espresso --task language_modeling_for_asr \ --workers 30 --srcdict $wordlmdict --only-source \ --trainpref $wordlmdatadir/train \ --validpref $wordlmdatadir/$valid_set \ @@ -184,7 +184,7 @@ if [ ${stage} -le 4 ] && ! $use_wordlm; then mkdir -p $lmdir/log log_file=$lmdir/log/train.log [ -f $lmdir/checkpoint_last.pt ] && log_file="-a $log_file" - CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../train.py $lmdatadir --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../fairseq_cli/train.py $lmdatadir --seed 1 --user-dir espresso \ --task language_modeling_for_asr --dict $lmdict \ --log-interval $((4000/ngpus)) --log-format simple \ --num-workers 0 --max-tokens 25600 --max-sentences 128 \ @@ -201,7 +201,7 @@ if [ ${stage} -le 5 ] && ! 
$use_wordlm; then echo "Stage 5: char LM Evaluation" for gen_subset in valid test; do log_file=$lmdir/log/evaluation_$gen_subset.log - python3 ../../eval_lm.py $lmdatadir --user-dir espresso --cpu \ + python3 ../../fairseq_cli/eval_lm.py $lmdatadir --user-dir espresso --cpu \ --task language_modeling_for_asr --dict $lmdict --gen-subset $gen_subset \ --max-tokens 192000 --max-sentences 256 --sample-break-mode eos \ --path $lmdir/$lm_checkpoint 2>&1 | tee $log_file @@ -214,7 +214,7 @@ if [ ${stage} -le 6 ] && $use_wordlm; then mkdir -p $wordlmdir/log log_file=$wordlmdir/log/train.log [ -f $wordlmdir/checkpoint_last.pt ] && log_file="-a $log_file" - CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../train.py $wordlmdatadir --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../fairseq_cli/train.py $wordlmdatadir --seed 1 --user-dir espresso \ --task language_modeling_for_asr --dict $wordlmdict \ --log-interval $((4000/ngpus)) --log-format simple \ --num-workers 0 --max-tokens 6400 --max-sentences 256 \ @@ -232,7 +232,7 @@ if [ ${stage} -le 7 ] && $use_wordlm; then echo "Stage 7: word LM Evaluation" for gen_subset in valid test; do log_file=$wordlmdir/log/evaluation_$gen_subset.log - python3 ../../eval_lm.py $wordlmdatadir --user-dir espresso --cpu \ + python3 ../../fairseq_cli/eval_lm.py $wordlmdatadir --user-dir espresso --cpu \ --task language_modeling_for_asr --dict $wordlmdict --gen-subset $gen_subset \ --max-tokens 12800 --max-sentences 512 --sample-break-mode eos \ --path $wordlmdir/$lm_checkpoint 2>&1 | tee $log_file diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index 972c80d0d..fc2a15f2c 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -1040,7 +1040,11 @@ def forward_encoder(self, net_input: Dict[str, Tensor]): @torch.jit.export def forward_decoder( - self, tokens, encoder_outs: List[EncoderOut], temperature: float = 1.0 + self, + tokens, + encoder_outs: List[EncoderOut], + incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]], + temperature: float = 1.0, ): log_probs = [] avg_attn: Optional[Tensor] = None @@ -1053,7 +1057,7 @@ def forward_decoder( decoder_out = model.decoder.forward( tokens, encoder_out=encoder_out, - incremental_state=self.incremental_states[i], + incremental_state=incremental_states[i], ) else: decoder_out = model.decoder.forward(tokens, encoder_out=encoder_out) From e42ec43f36dfafa039644c9f2868b3e428844a4e Mon Sep 17 00:00:00 2001 From: Yiming Wang Date: Thu, 2 Jul 2020 13:55:20 -0400 Subject: [PATCH 087/119] Update Transformer models (#31) * update transformer * initial recipe * fix transformer * add encoder positional embeddings * add more recipes --- espresso/criterions/cross_entropy_v2.py | 5 +- .../label_smoothed_cross_entropy_v2.py | 5 +- espresso/models/speech_transformer.py | 548 ++++++++++++++---- .../speech_transformer_encoder_model.py | 301 ++++++++++ espresso/speech_recognize.py | 2 +- examples/asr_librispeech/run.sh | 41 +- examples/asr_swbd/run.sh | 41 +- examples/asr_wsj/run.sh | 31 +- 8 files changed, 820 insertions(+), 154 deletions(-) create mode 100644 espresso/models/speech_transformer_encoder_model.py diff --git a/espresso/criterions/cross_entropy_v2.py b/espresso/criterions/cross_entropy_v2.py index 1888d69b9..39f11df00 100644 --- a/espresso/criterions/cross_entropy_v2.py +++ b/espresso/criterions/cross_entropy_v2.py @@ -26,6 +26,7 @@ def __init__(self, task, sentence_avg, print_interval): self.dictionary = task.target_dictionary self.print_interval = 
print_interval self.epoch = 1 + self.prev_num_updates = -1 @staticmethod def add_args(parser): @@ -59,8 +60,10 @@ def forward(self, model, sample, reduce=True): if ( hasattr(model, "num_updates") and model.training and model.num_updates // self.print_interval > - (model.num_updates - 1) // self.print_interval + (model.num_updates - 1) // self.print_interval and + model.num_updates != self.prev_num_updates ): # print a randomly sampled result every print_interval updates + self.prev_num_updates = model.num_updates target = model.get_targets(sample, net_output) pred = lprobs.argmax(-1).cpu() # bsz x len assert pred.size() == target.size() diff --git a/espresso/criterions/label_smoothed_cross_entropy_v2.py b/espresso/criterions/label_smoothed_cross_entropy_v2.py index 383a4a1f7..8a522c0e1 100644 --- a/espresso/criterions/label_smoothed_cross_entropy_v2.py +++ b/espresso/criterions/label_smoothed_cross_entropy_v2.py @@ -94,6 +94,7 @@ def __init__( self.unigram_tensor = torch.FloatTensor(self.dictionary.count).unsqueeze(-1) self.unigram_tensor += unigram_pseudo_count # for further backoff self.unigram_tensor.div_(self.unigram_tensor.sum()) + self.prev_num_updates = -1 @staticmethod def add_args(parser): @@ -137,8 +138,10 @@ def forward(self, model, sample, reduce=True): if ( hasattr(model, "num_updates") and model.training and model.num_updates // self.print_interval > - (model.num_updates - 1) // self.print_interval + (model.num_updates - 1) // self.print_interval and + model.num_updates != self.prev_num_updates ): # print a randomly sampled result every print_interval updates + self.prev_num_updates = model.num_updates target = model.get_targets(sample, net_output) pred = lprobs.argmax(-1).cpu() # bsz x len assert pred.size() == target.size() diff --git a/espresso/models/speech_transformer.py b/espresso/models/speech_transformer.py index cf4f4ddcc..de2875b5b 100644 --- a/espresso/models/speech_transformer.py +++ b/espresso/models/speech_transformer.py @@ -4,32 +4,41 @@ # LICENSE file in the root directory of this source tree. import logging +from typing import Any, Dict, List, Optional import torch +from torch import Tensor import torch.nn as nn import torch.nn.functional as F -from fairseq import utils +from fairseq import options from fairseq.models import ( register_model, register_model_architecture, ) +from fairseq.models.fairseq_encoder import EncoderOut from fairseq.models.transformer import ( Embedding, Linear, TransformerModel, TransformerEncoder, TransformerDecoder, - TransformerEncoderLayer, ) -from fairseq.modules import LayerNorm +from fairseq.modules import ( + LayerDropModuleList, + LayerNorm, + PositionalEmbedding, + TransformerDecoderLayer, +) +from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_ from espresso.models.speech_lstm import ConvBNReLU +from espresso.tools.scheduled_sampling_rate_scheduler import ScheduledSamplingRateScheduler import espresso.tools.utils as speech_utils -DEFAULT_MAX_SOURCE_POSITIONS = 9999 -DEFAULT_MAX_TARGET_POSITIONS = 999 +DEFAULT_MAX_SOURCE_POSITIONS = 10240 +DEFAULT_MAX_TARGET_POSITIONS = 1024 logger = logging.getLogger(__name__) @@ -39,11 +48,12 @@ class SpeechTransformerModel(TransformerModel): """ Transformer model from `"Attention Is All You Need" (Vaswani, et al, 2017) - `_. + `_. It adds 2D convolutions before + transformer layers in the encoder to process speech input. 
Args: - encoder (TransformerEncoder): the encoder - decoder (TransformerDecoder): the decoder + encoder (SpeechTransformerEncoder): the encoder + decoder (SpeechTransformerDecoder): the decoder The Transformer model provides the following named architectures and command-line arguments: @@ -66,12 +76,28 @@ def add_args(parser): """Add model-specific arguments to the parser.""" # fmt: off TransformerModel.add_args(parser) - parser.add_argument('--encoder-conv-channels', type=str, metavar='EXPR', - help='list of encoder convolution\'s out channels') - parser.add_argument('--encoder-conv-kernel-sizes', type=str, metavar='EXPR', - help='list of encoder convolution\'s kernel sizes') - parser.add_argument('--encoder-conv-strides', type=str, metavar='EXPR', - help='list of encoder convolution\'s strides') + parser.add_argument("--encoder-conv-channels", type=str, metavar="EXPR", + help="list of encoder convolution\'s out channels") + parser.add_argument("--encoder-conv-kernel-sizes", type=str, metavar="EXPR", + help="list of encoder convolution\'s kernel sizes") + parser.add_argument("--encoder-conv-strides", type=str, metavar="EXPR", + help="list of encoder convolution\'s strides") + parser.add_argument("--encoder-transformer-context", type=str, metavar="EXPR", + help="left/right context for time-restricted self-attention; " + "can be None or a tuple of two non-negative integers/None") + parser.add_argument("--decoder-input-dim", type=int, metavar="N", + help="decoder input dimension (extra linear layer " + "if different from decoder embed dim)") + + # Scheduled sampling options + parser.add_argument("--scheduled-sampling-probs", type=lambda p: options.eval_str_list(p), + metavar="P_1,P_2,...,P_N", default=[1.0], + help="scheduled sampling probabilities of sampling the truth " + "labels for N epochs starting from --start-schedule-sampling-epoch; " + "all later epochs using P_N") + parser.add_argument("--start-scheduled-sampling-epoch", type=int, + metavar="N", default=1, + help="start scheduled sampling from the specified epoch") # fmt: on @classmethod @@ -81,31 +107,26 @@ def build_model(cls, args, task): # make sure all arguments are present in older models base_architecture(args) - if not hasattr(args, 'max_source_positions'): + if args.encoder_layers_to_keep: + args.encoder_layers = len(args.encoder_layers_to_keep.split(",")) + if args.decoder_layers_to_keep: + args.decoder_layers = len(args.decoder_layers_to_keep.split(",")) + + if getattr(args, "max_source_positions", None) is None: args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS - if not hasattr(args, 'max_target_positions'): + if getattr(args, "max_target_positions", None) is None: args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS - dict = task.target_dictionary - - def build_embedding(dictionary, embed_dim, path=None): - num_embeddings = len(dictionary) - padding_idx = dictionary.pad() - emb = Embedding(num_embeddings, embed_dim, padding_idx) - # if provided, load from preloaded dictionaries - if path: - embed_dict = utils.parse_embedding(path) - utils.load_embedding(embed_dict, dictionary, emb) - return emb + tgt_dict = task.target_dictionary - decoder_embed_tokens = build_embedding( - dict, args.decoder_embed_dim, args.decoder_embed_path + decoder_embed_tokens = cls.build_embedding( + args, tgt_dict, args.decoder_input_dim, args.decoder_embed_path ) out_channels = speech_utils.eval_str_nested_list_or_tuple(args.encoder_conv_channels, type=int) kernel_sizes = 
speech_utils.eval_str_nested_list_or_tuple(args.encoder_conv_kernel_sizes, type=int) strides = speech_utils.eval_str_nested_list_or_tuple(args.encoder_conv_strides, type=int) - logger.info('input feature dimension: {}, channels: {}'.format(task.feat_dim, task.feat_in_channels)) + logger.info("input feature dimension: {}, channels: {}".format(task.feat_dim, task.feat_in_channels)) assert task.feat_dim % task.feat_in_channels == 0 conv_layers = ConvBNReLU( out_channels, kernel_sizes, strides, in_channels=task.feat_in_channels, @@ -120,35 +141,99 @@ def build_embedding(dictionary, embed_dim, path=None): else: assert isinstance(stride, int) s = stride - transformer_encoder_input_size = \ - (transformer_encoder_input_size + s - 1) // s + transformer_encoder_input_size = (transformer_encoder_input_size + s - 1) // s transformer_encoder_input_size *= out_channels[-1] + else: + transformer_encoder_input_size = task.feat_dim + + encoder_transformer_context = speech_utils.eval_str_nested_list_or_tuple( + args.encoder_transformer_context, type=int, + ) + if encoder_transformer_context is not None: + assert len(encoder_transformer_context) == 2 + for i in range(2): + assert ( + encoder_transformer_context[i] is None + or ( + isinstance(encoder_transformer_context[i], int) + and encoder_transformer_context[i] >= 0 + ) + ) + + scheduled_sampling_rate_scheduler = ScheduledSamplingRateScheduler( + args.scheduled_sampling_probs, args.start_scheduled_sampling_epoch, + ) encoder = cls.build_encoder( args, conv_layers_before=conv_layers, input_size=transformer_encoder_input_size, + transformer_context=encoder_transformer_context, ) - decoder = cls.build_decoder(args, dict, decoder_embed_tokens) - return cls(encoder, decoder) + decoder = cls.build_decoder( + args, tgt_dict, decoder_embed_tokens, + scheduled_sampling_rate_scheduler=scheduled_sampling_rate_scheduler, + ) + return cls(args, encoder, decoder) def set_num_updates(self, num_updates): self.num_updates = num_updates super().set_num_updates(num_updates) @classmethod - def build_encoder(cls, args, conv_layers_before=None, input_size=83): + def build_encoder(cls, args, conv_layers_before=None, input_size=83, transformer_context=None): return SpeechTransformerEncoder( args, conv_layers_before=conv_layers_before, input_size=input_size, + transformer_context=transformer_context, ) @classmethod - def build_decoder(cls, args, dict, embed_tokens): - return SpeechTransformerDecoder(args, dict, embed_tokens) + def build_decoder(cls, args, tgt_dict, embed_tokens, scheduled_sampling_rate_scheduler=None): + return SpeechTransformerDecoder( + args, + tgt_dict, + embed_tokens, + no_encoder_attn=getattr(args, "no_cross_attention", False), + scheduled_sampling_rate_scheduler=scheduled_sampling_rate_scheduler, + ) + + # TorchScript doesn't support optional arguments with variable length (**kwargs). + # Current workaround is to add union of all arguments in child classes. + def forward( + self, + src_tokens, + src_lengths, + prev_output_tokens, + return_all_hiddens: bool = True, + features_only: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + epoch=1, + ): + """ + Run the forward pass for an encoder-decoder model. + + Copied from the base class, but without ``**kwargs``, + which are not supported by TorchScript. 
+ """ + encoder_out = self.encoder( + src_tokens, src_lengths=src_lengths, return_all_hiddens=return_all_hiddens + ) + decoder_out = self.decoder( + prev_output_tokens, + encoder_out=encoder_out, + features_only=features_only, + alignment_layer=alignment_layer, + alignment_heads=alignment_heads, + src_lengths=src_lengths, + return_all_hiddens=return_all_hiddens, + epoch=epoch, + ) + return decoder_out class SpeechTransformerEncoder(TransformerEncoder): """ - Transformer encoder consisting of *args.encoder_layers* layers. Each layer - is a :class:`TransformerEncoderLayer`. + Transformer encoder consisting of 2D convolution layers and + *args.encoder_layers* layers. Each layer is a :class:`TransformerEncoderLayer`. Args: args (argparse.Namespace): parsed command-line arguments @@ -158,46 +243,116 @@ class SpeechTransformerEncoder(TransformerEncoder): before being projected to args.encoder_embed_dim """ - def __init__(self, args, conv_layers_before=None, input_size=83): + def __init__(self, args, conv_layers_before=None, input_size=83, transformer_context=None): super(TransformerEncoder, self).__init__(None) # no src dictionary - self.register_buffer('version', torch.Tensor([3])) + self.register_buffer("version", torch.Tensor([3])) self.dropout = args.dropout + self.encoder_layerdrop = args.encoder_layerdrop - self.conv_layers_before = conv_layers_before - self.fc0 = Linear(input_size, args.encoder_embed_dim) \ - if input_size != args.encoder_embed_dim else None + embed_dim = args.encoder_embed_dim self.max_source_positions = args.max_source_positions - self.layers = nn.ModuleList([]) - self.layers.extend([ - TransformerEncoderLayer(args) - for i in range(args.encoder_layers) - ]) + self.embed_positions = None + + self.conv_layers_before = conv_layers_before + self.fc0 = Linear(input_size, embed_dim) if input_size != embed_dim else None + + self.embed_positions = ( + PositionalEmbedding( + self.output_lengths(args.max_source_positions), + embed_dim, + 0, + learned=args.encoder_learned_pos, + ) + if not args.no_token_positional_embeddings + else None + ) + + if not args.adaptive_input and args.quant_noise_pq > 0: + self.quant_noise = apply_quant_noise_( + nn.Linear(embed_dim, embed_dim, bias=False), + args.quant_noise_pq, + args.quant_noise_pq_block_size, + ) + else: + self.quant_noise = None + + if self.encoder_layerdrop > 0.0: + self.layers = LayerDropModuleList(p=self.encoder_layerdrop) + else: + self.layers = nn.ModuleList([]) + self.layers.extend( + [self.build_encoder_layer(args) for i in range(args.encoder_layers)] + ) + self.num_layers = len(self.layers) if args.encoder_normalize_before: - self.layer_norm = LayerNorm(args.encoder_embed_dim) + self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None + if getattr(args, "layernorm_embedding", False): + self.layernorm_embedding = LayerNorm(embed_dim) + else: + self.layernorm_embedding = None + + self.transformer_context = transformer_context def output_lengths(self, in_lengths): return in_lengths if self.conv_layers_before is None \ else self.conv_layers_before.output_lengths(in_lengths) - def forward(self, src_tokens, src_lengths): + def get_attn_mask(self, in_lengths): + """ + Create attention mask according to sequence lengths and transformer context. 
+
+        Args:
+            in_lengths (LongTensor): lengths of each input sequence of shape `(batch)`
+
+        Returns:
+            attn_mask (ByteTensor|BoolTensor, optional): self-attention mask of shape
+            `(tgt_len, src_len)`, where `tgt_len` is the length of output and `src_len`
+            is the length of input, though here both are equal to `seq_len`.
+            `attn_mask[tgt_i, src_j] = 1` means that when calculating the
+            embedding for `tgt_i`, we exclude (mask out) `src_j`.
+        """
+        if (
+            self.transformer_context is None
+            or (self.transformer_context[0] is None and self.transformer_context[1] is None)
+        ):
+            return None
+        max_len = in_lengths.data.max()
+        all_ones = in_lengths.new_ones([max_len, max_len], dtype=torch.bool)
+        # at this point left and right context cannot be both None
+        if self.transformer_context[0] is None:  # mask is a triu matrix
+            return all_ones.triu(self.transformer_context[1] + 1)
+        if self.transformer_context[1] is None:  # mask is a tril matrix
+            return all_ones.tril(-self.transformer_context[0] - 1)
+        return (
+            all_ones.triu(self.transformer_context[1] + 1) | all_ones.tril(-self.transformer_context[0] - 1)
+        )
+
+    def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False):
         """
         Args:
             src_tokens (LongTensor): tokens in the source language of shape
                 `(batch, src_len)`
             src_lengths (torch.LongTensor): lengths of each source sentence of
                 shape `(batch)`
+            return_all_hiddens (bool, optional): also return all of the
+                intermediate hidden states (default: False).
 
         Returns:
-            dict:
+            namedtuple:
                 - **encoder_out** (Tensor): the last encoder layer's output of
                   shape `(src_len, batch, embed_dim)`
                 - **encoder_padding_mask** (ByteTensor): the positions of
                   padding elements of shape `(batch, src_len)`
+                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
+                  of shape `(batch, src_len, embed_dim)`
+                - **encoder_states** (List[Tensor]): all intermediate
+                  hidden states of shape `(src_len, batch, embed_dim)`.
+                  Only populated if *return_all_hiddens* is True. 
""" if self.conv_layers_before is not None: x, src_lengths, encoder_padding_mask = self.conv_layers_before(src_tokens, src_lengths) @@ -205,98 +360,263 @@ def forward(self, src_tokens, src_lengths): x, encoder_padding_mask = src_tokens, \ ~speech_utils.sequence_mask(src_lengths, src_tokens.size(1)) - if not encoder_padding_mask.any(): - encoder_padding_mask = None - x = F.dropout(x, p=self.dropout, training=self.training) if self.fc0 is not None: x = self.fc0(x) + if self.embed_positions is not None: + # 0s in `~encoder_padding_mask` are used as pad_idx for positional embeddings + x = x + self.embed_positions((~encoder_padding_mask).int()) + if self.layernorm_embedding is not None: + x = self.layernorm_embedding(x) x = F.dropout(x, p=self.dropout, training=self.training) + elif self.embed_positions is not None: + # 0s in `~encoder_padding_mask` are used as pad_idx for positional embeddings + x = x + self.embed_positions((~encoder_padding_mask).int()) + if self.layernorm_embedding is not None: + x = self.layernorm_embedding(x) + + if not encoder_padding_mask.any(): + encoder_padding_mask = None # B x T x C -> T x B x C x = x.transpose(0, 1) + attn_mask = self.get_attn_mask(src_lengths) + + encoder_states = [] if return_all_hiddens else None + # encoder layers for layer in self.layers: - x = layer(x, encoder_padding_mask) + x = layer(x, encoder_padding_mask, attn_mask=attn_mask) + if return_all_hiddens: + assert encoder_states is not None + encoder_states.append(x) - if self.layer_norm: + if self.layer_norm is not None: x = self.layer_norm(x) - return { - 'encoder_out': x, # T x B x C - 'encoder_padding_mask': encoder_padding_mask, # B x T - } + return EncoderOut( + encoder_out=x, # T x B x C + encoder_padding_mask=encoder_padding_mask, # B x T + encoder_embedding=None, + encoder_states=encoder_states, # List[T x B x C] + src_tokens=None, + src_lengths=None, + ) def max_positions(self): """Maximum input length supported by the encoder.""" return self.max_source_positions - def upgrade_state_dict_named(self, state_dict, name): - """Upgrade a (possibly old) state dict for new versions of fairseq.""" - for i in range(len(self.layers)): - # update layer norms - self.layers[i].upgrade_state_dict_named(state_dict, "{}.layers.{}".format(name, i)) - version_key = '{}.version'.format(name) - if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: - # earlier checkpoints did not normalize after the stack of layers - self.layer_norm = None - self.normalize = False - state_dict[version_key] = torch.Tensor([1]) - return state_dict +class SpeechTransformerDecoder(TransformerDecoder): + def __init__( + self, args, dictionary, embed_tokens, no_encoder_attn=False, + scheduled_sampling_rate_scheduler=None, + ): + super().__init__(args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn) + self.scheduled_sampling_rate_scheduler = scheduled_sampling_rate_scheduler + for layer in self.layers: + if isinstance(layer, TransformerDecoderLayer): + layer.need_attn = False # make validation fast + + def forward( + self, + prev_output_tokens, + encoder_out: Optional[EncoderOut] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + features_only: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + src_lengths: Optional[Any] = None, + return_all_hiddens: bool = False, + **kwargs, + ): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for input feeding/teacher forcing + 
encoder_out (EncoderOut, optional): output from the encoder, used for + encoder-side attention + incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + features_only (bool, optional): only return features without + applying output layer (default: False). + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + if self.training and alignment_layer is None: # no attention tensors during training to save memory + alignment_layer = self.num_layers # can be any value no less than this + if self.training and self.scheduled_sampling_rate_scheduler is not None: + epoch = kwargs.get("epoch", 1) + sampling_prob = self.scheduled_sampling_rate_scheduler.step(epoch) + if sampling_prob < 1.0: # apply scheduled sampling + assert not features_only + return self._forward_with_scheduled_sampling( + prev_output_tokens, sampling_prob, encoder_out=encoder_out, + incremental_state={}, # use empty dict to preserve forward state + alignment_layer=alignment_layer, + alignment_heads=alignment_heads, + src_lengths=src_lengths, + return_all_hiddens=return_all_hiddens, + ) + + x, extra = self.extract_features( + prev_output_tokens, + encoder_out=encoder_out, + incremental_state=incremental_state, + alignment_layer=alignment_layer, + alignment_heads=alignment_heads, + ) + if not features_only: + x = self.output_layer(x) + return x, extra + + def _forward_with_scheduled_sampling( + self, + prev_output_tokens, + sampling_prob, + encoder_out: Optional[EncoderOut] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + features_only: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + src_lengths: Optional[Any] = None, + return_all_hiddens: bool = False, + ): + bsz, seqlen = prev_output_tokens.size() + outs = [] + pred = None + for step in range(seqlen): + if step > 0: + sampling_mask = torch.rand( + [bsz, 1], device=prev_output_tokens.device, + ).lt(sampling_prob) + feed_tokens = torch.where( + sampling_mask, prev_output_tokens[:, step:step + 1], pred, + ) + else: + feed_tokens = prev_output_tokens[:, step:step + 1] # B x 1 + x, _ = self.extract_features( + feed_tokens, + encoder_out=encoder_out, + incremental_state=incremental_state, + alignment_layer=alignment_layer, + alignment_heads=alignment_heads, + ) + x = self.output_layer(x) # B x 1 x V + outs.append(x) + pred = x.argmax(-1) # B x 1 + x = torch.cat(outs, dim=1) # B x T x V + return x, None -class SpeechTransformerDecoder(TransformerDecoder): def masked_copy_incremental_state(self, incremental_state, another_cached_state, mask): - pass + raise NotImplementedError -@register_model_architecture('speech_transformer', 'speech_transformer') +@register_model_architecture("speech_transformer", "speech_transformer") def base_architecture(args): args.encoder_conv_channels = getattr( - args, 'encoder_conv_channels', '[64, 64, 128, 128]', + args, "encoder_conv_channels", "[64, 64, 128, 128]", ) args.encoder_conv_kernel_sizes = getattr( - args, 'encoder_conv_kernel_sizes', '[(3, 3), (3, 3), (3, 3), (3, 3)]', + args, "encoder_conv_kernel_sizes", "[(3, 3), (3, 3), (3, 3), (3, 3)]", ) args.encoder_conv_strides = getattr( - args, 'encoder_conv_strides', '[(1, 1), (2, 2), (1, 1), (2, 2)]', + args, "encoder_conv_strides", "[(1, 1), (2, 2), (1, 1), (2, 2)]", ) - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512) - args.encoder_ffn_embed_dim = getattr(args, 
'encoder_ffn_embed_dim', 512) - args.encoder_layers = getattr(args, 'encoder_layers', 6) - args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8) - args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False) - args.decoder_embed_path = getattr(args, 'decoder_embed_path', None) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512) - args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', args.encoder_ffn_embed_dim) - args.decoder_layers = getattr(args, 'decoder_layers', 6) - args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8) - args.decoder_normalize_before = getattr(args, 'decoder_normalize_before', False) - args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False) - args.attention_dropout = getattr(args, 'attention_dropout', 0.) - args.activation_dropout = getattr(args, 'activation_dropout', 0.) - args.activation_fn = getattr(args, 'activation_fn', 'relu') - args.dropout = getattr(args, 'dropout', 0.1) - args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None) - args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0) - args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False) - args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False) - args.adaptive_input = getattr(args, 'adaptive_input', False) - - args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim) - args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim) - - -@register_model_architecture('speech_transformer', 'speech_transformer_librispeech') + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024) + args.encoder_layers = getattr(args, "encoder_layers", 12) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) + args.encoder_transformer_context = getattr(args, "encoder_transformer_context", None) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.attention_dropout = getattr(args, "attention_dropout", 0.2) + args.activation_dropout = getattr(args, "activation_dropout", 0.2) + args.activation_fn = getattr(args, "activation_fn", "relu") + args.dropout = getattr(args, "dropout", 0.2) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.no_cross_attention = getattr(args, "no_cross_attention", False) + args.cross_self_attention = 
getattr(args, "cross_self_attention", False) + + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) + args.layernorm_embedding = getattr(args, "layernorm_embedding", False) + args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) + + +@register_model_architecture("speech_transformer", "speech_transformer_wsj") +def speech_transformer_wsj(args): + base_architecture(args) + + +@register_model_architecture("speech_transformer", "speech_transformer_librispeech") def speech_transformer_librispeech(args): - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512) - args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 512) - args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 1) - args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512) - args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 512) - args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 1) - args.dropout = getattr(args, 'dropout', 0.3) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_layers = getattr(args, "encoder_layers", 12) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + args.encoder_transformer_context = getattr(args, "encoder_transformer_context", None) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_dropout = getattr(args, "activation_dropout", 0.1) + args.dropout = getattr(args, "dropout", 0.1) + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + base_architecture(args) + + +@register_model_architecture("speech_transformer", "speech_transformer_swbd") +def speech_transformer_swbd(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_layers = getattr(args, "encoder_layers", 12) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.encoder_transformer_context = getattr(args, "encoder_transformer_context", None) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) + args.attention_dropout = getattr(args, "attention_dropout", 0.25) + args.activation_dropout = getattr(args, "activation_dropout", 0.25) + args.dropout = getattr(args, "dropout", 0.25) + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) 
base_architecture(args) diff --git a/espresso/models/speech_transformer_encoder_model.py b/espresso/models/speech_transformer_encoder_model.py new file mode 100644 index 000000000..b4c49edf6 --- /dev/null +++ b/espresso/models/speech_transformer_encoder_model.py @@ -0,0 +1,301 @@ +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Optional + +import torch +import torch.nn.functional as F + +from fairseq import utils +from fairseq.models import ( + FairseqEncoderModel, + register_model, + register_model_architecture, +) +from fairseq.models.fairseq_encoder import EncoderOut +from fairseq.models.transformer import Linear + +from espresso.models.speech_lstm import ConvBNReLU +from espresso.models.speech_transformer import SpeechTransformerEncoder +import espresso.tools.utils as speech_utils + + +DEFAULT_MAX_SOURCE_POSITIONS = 10240 + + +logger = logging.getLogger(__name__) + + +@register_model("speech_transformer_encoder_model") +class SpeechTransformerEncoderModel(FairseqEncoderModel): + def __init__(self, args, encoder, state_prior: Optional[torch.FloatTensor] = None): + super().__init__(encoder) + self.args = args + self.state_prior = state_prior + self.num_updates = 0 + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + # fmt: off + parser.add_argument("--activation-fn", + choices=utils.get_available_activation_fns(), + help="activation function to use") + parser.add_argument("--dropout", type=float, metavar="D", + help="dropout probability") + parser.add_argument("--encoder-conv-channels", type=str, metavar="EXPR", + help="list of encoder convolution's out channels") + parser.add_argument("--encoder-conv-kernel-sizes", type=str, metavar="EXPR", + help="list of encoder convolution's kernel sizes") + parser.add_argument("--encoder-conv-strides", type=str, metavar="EXPR", + help="list of encoder convolution's strides") + parser.add_argument("--attention-dropout", type=float, metavar="D", + help="dropout probability for attention weights") + parser.add_argument("--activation-dropout", "--relu-dropout", type=float, metavar="D", + help="dropout probability after activation in FFN.") + parser.add_argument("--encoder-ffn-embed-dim", type=int, metavar="N", + help="encoder embedding dimension for FFN") + parser.add_argument("--encoder-layers", type=int, metavar="N", + help="num encoder layers") + parser.add_argument("--encoder-attention-heads", type=int, metavar="N", + help="num encoder attention heads") + parser.add_argument("--encoder-normalize-before", action="store_true", + help="apply layernorm before each encoder block") + parser.add_argument("--encoder-transformer-context", type=str, metavar="EXPR", + help="left/right context for time-restricted self-attention; " + "can be None or a tuple of two non-negative integers/None") + # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019) + parser.add_argument("--encoder-layerdrop", type=float, metavar="D", default=0, + help="LayerDrop probability for encoder") + parser.add_argument("--encoder-layers-to-keep", default=None, + help="which layers to *keep* when pruning as a comma-separated list") + # args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020) + parser.add_argument("--quant-noise-pq", type=float, metavar="D", default=0, + help="iterative PQ quantization noise at training 
time") + parser.add_argument("--quant-noise-pq-block-size", type=int, metavar="D", default=8, + help="block size of quantization noise at training time") + parser.add_argument("--quant-noise-scalar", type=float, metavar="D", default=0, + help="scalar quantization noise and scalar quantization at training time") + # fmt: on + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure that all args are properly defaulted (in case there are any new ones) + base_architecture(args) + + if args.encoder_layers_to_keep: + args.encoder_layers = len(args.encoder_layers_to_keep.split(",")) + + if getattr(args, "max_source_positions", None) is None: + args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS + + out_channels = speech_utils.eval_str_nested_list_or_tuple(args.encoder_conv_channels, type=int) + kernel_sizes = speech_utils.eval_str_nested_list_or_tuple(args.encoder_conv_kernel_sizes, type=int) + strides = speech_utils.eval_str_nested_list_or_tuple(args.encoder_conv_strides, type=int) + logger.info("input feature dimension: {}, channels: {}".format(task.feat_dim, task.feat_in_channels)) + assert task.feat_dim % task.feat_in_channels == 0 + conv_layers = ConvBNReLU( + out_channels, kernel_sizes, strides, in_channels=task.feat_in_channels, + ) if out_channels is not None else None + + transformer_encoder_input_size = task.feat_dim // task.feat_in_channels + if conv_layers is not None: + for stride in strides: + if isinstance(stride, (list, tuple)): + assert len(stride) > 0 + s = stride[1] if len(stride) > 1 else stride[0] + else: + assert isinstance(stride, int) + s = stride + transformer_encoder_input_size = (transformer_encoder_input_size + s - 1) // s + transformer_encoder_input_size *= out_channels[-1] + else: + transformer_encoder_input_size = task.feat_dim + + encoder_transformer_context = speech_utils.eval_str_nested_list_or_tuple( + args.encoder_transformer_context, type=int, + ) + if encoder_transformer_context is not None: + assert len(encoder_transformer_context) == 2 + for i in range(2): + assert ( + encoder_transformer_context[i] is None + or ( + isinstance(encoder_transformer_context[i], int) + and encoder_transformer_context[i] >= 0 + ) + ) + + encoder = cls.build_encoder( + args, + conv_layers_before=conv_layers, + input_size=transformer_encoder_input_size, + transformer_context=encoder_transformer_context, + num_targets=getattr(task, "num_targets", None), # targets for encoder-only model + chunk_width=getattr(task, "chunk_width", None), + chunk_left_context=getattr(task, "chunk_left_context", 0), + training_stage=getattr(task, "training_stage", True), + ) + return cls(args, encoder, state_prior=getattr(task, "initial_state_prior", None)) + + def set_num_updates(self, num_updates): + self.num_updates = num_updates + super().set_num_updates(num_updates) + + @classmethod + def build_encoder(cls, args, conv_layers_before=None, input_size=83, transformer_context=None, + num_targets=None, chunk_width=None, chunk_left_context=0, training_stage=True, + ): + return SpeechChunkTransformerEncoder( + args, + conv_layers_before=conv_layers_before, + input_size=input_size, + transformer_context=transformer_context, + num_targets=num_targets, + chunk_width=chunk_width, + chunk_left_context=chunk_left_context, + training_stage=training_stage, + ) + + def output_lengths(self, in_lengths): + return self.encoder.output_lengths(in_lengths) + + def get_normalized_probs(self, net_output, log_probs, sample=None): + """Get normalized probabilities (or log 
probs) from a net's output.""" + encoder_out = net_output.encoder_out + if torch.is_tensor(encoder_out): + logits = encoder_out.float() + if log_probs: + return F.log_softmax(logits, dim=-1) + else: + return F.softmax(logits, dim=-1) + raise NotImplementedError + + def update_state_prior(self, new_state_prior, factor=0.1): + assert self.state_prior is not None + self.state_prior = self.state_prior.to(new_state_prior) + self.state_prior = (1. - factor) * self.state_prior + factor * new_state_prior + self.state_prior = self.state_prior / self.state_prior.sum() # re-normalize + + def state_dict(self): + state_dict = super().state_dict() + state_dict["state_prior"] = self.state_prior + return state_dict + + def load_state_dict(self, state_dict, strict=True, args=None): + state_dict_subset = state_dict.copy() + self.state_prior = state_dict.get("state_prior", None) + if "state_prior" in state_dict: + self.state_prior = state_dict["state_prior"] + del state_dict_subset["state_prior"] + super().load_state_dict(state_dict_subset, strict=strict, args=args) + + +class SpeechChunkTransformerEncoder(SpeechTransformerEncoder): + """Transformer encoder for speech (possibly chunk) data.""" + def __init__( + self, args, conv_layers_before=None, input_size=83, transformer_context=None, + num_targets=None, chunk_width=None, chunk_left_context=0, training_stage=True, + ): + super().__init__( + args, conv_layers_before=conv_layers_before, input_size=input_size, + transformer_context=transformer_context, + ) + receptive_field_radius = sum(conv.padding[0] for conv in conv_layers_before.convolutions) \ + if conv_layers_before is not None else 0 + assert chunk_width is None or chunk_width > 0 + assert (conv_layers_before is None and chunk_left_context >= 0) or \ + (conv_layers_before is not None and chunk_left_context >= receptive_field_radius) + self.out_chunk_begin = self.output_lengths(chunk_left_context + 1) - 1 + self.out_chunk_end = self.output_lengths(chunk_left_context + chunk_width) \ + if chunk_width is not None else None + self.training_stage = training_stage + + # only for encoder-only model + self.fc_out = Linear(args.encoder_embed_dim, num_targets, dropout=self.dropout) \ + if num_targets is not None else None + + def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False): + """ + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (LongTensor): lengths of each source sentence of + shape `(batch)` + return_all_hiddens (bool, optional): also return all of the + intermediate hidden states (default: False). + + Returns: + namedtuple: + - **encoder_out** (Tensor): the last encoder layer's output of + shape `(src_len, batch, embed_dim)` + - **encoder_padding_mask** (ByteTensor): the positions of + padding elements of shape `(batch, src_len)` + - **encoder_embedding** (Tensor): the (scaled) embedding lookup + of shape `(batch, src_len, embed_dim)` + - **encoder_states** (List[Tensor]): all intermediate + hidden states of shape `(src_len, batch, embed_dim)`. + Only populated if *return_all_hiddens* is True. 
+ """ + out = super().forward(src_tokens, src_lengths, return_all_hiddens=return_all_hiddens) + x, x_lengths = out.encoder_out, out.src_lengths + + # determine which output frame to select for loss evaluation/test, assuming + # all examples in a batch are of the same length for chunk-wise training/test + if ( + self.out_chunk_end is not None + and (self.training or not self.training_stage) + ): + x = x[self.out_chunk_begin: self.out_chunk_end] # T x B x C -> W x B x C + x_lengths = x_lengths.fill_(x.size(0)) + + if self.fc_out is not None: + x = self.fc_out(x) # T x B x C -> T x B x V + + return EncoderOut( + encoder_out=x, # T x B x C + encoder_padding_mask=out.encoder_padding_mask.transpose(0, 1), # T x B + encoder_embedding=out.encoder_embedding, # None + encoder_states=out.encoder_states, # List[T x B x C] + src_tokens=out.src_tokens, # None + src_lengths=x_lengths, # B + ) + + +@register_model_architecture("speech_transformer_encoder_model", "speech_transformer_encoder_model") +def base_architecture(args): + args.encoder_conv_channels = getattr( + args, "encoder_conv_channels", "[64, 64, 128, 128]", + ) + args.encoder_conv_kernel_sizes = getattr( + args, "encoder_conv_kernel_sizes", "[(3, 3), (3, 3), (3, 3), (3, 3)]", + ) + args.encoder_conv_strides = getattr( + args, "encoder_conv_strides", "[(1, 1), (2, 2), (1, 1), (2, 2)]", + ) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024) + args.encoder_layers = getattr(args, "encoder_layers", 12) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) + args.encoder_transformer_context = getattr(args, "encoder_transformer_context", None) + args.attention_dropout = getattr(args, "attention_dropout", 0.2) + args.activation_dropout = getattr(args, "activation_dropout", 0.2) + args.activation_fn = getattr(args, "activation_fn", "relu") + args.dropout = getattr(args, "dropout", 0.2) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.layernorm_embedding = getattr(args, "layernorm_embedding", False) + + +@register_model_architecture("speech_transformer_encoder_model", "speech_transformer_encoder_model_wsj") +def speech_transformer_encoder_wsj(args): + base_architecture(args) diff --git a/espresso/speech_recognize.py b/espresso/speech_recognize.py index 1ea9240df..892fcf5c0 100755 --- a/espresso/speech_recognize.py +++ b/espresso/speech_recognize.py @@ -35,7 +35,7 @@ def main(args): if args.results_path is not None: os.makedirs(args.results_path, exist_ok=True) output_path = os.path.join(args.results_path, 'decode.log') - with open(output_path, 'w', buffering=1) as h: + with open(output_path, 'w', buffering=1, encoding='utf-8') as h: return _main(args, h) return _main(args, sys.stdout) diff --git a/examples/asr_librispeech/run.sh b/examples/asr_librispeech/run.sh index 37d5e789f..c5a9c8228 100755 --- a/examples/asr_librispeech/run.sh +++ b/examples/asr_librispeech/run.sh @@ -16,6 +16,7 @@ train_set=train_960 valid_set=dev test_set="test_clean test_other dev_clean dev_other" checkpoint=checkpoint_best.pt +use_transformer=false # LM related lm_affix= @@ -43,7 +44,11 @@ apply_specaug=false . 
./utils/parse_options.sh lmdir=exp/lm_lstm${lm_affix:+_${lm_affix}} -dir=exp/lstm${affix:+_$affix} +if $use_transformer; then + dir=exp/transformer${affix:+_$affix} +else + dir=exp/lstm${affix:+_$affix} +fi if [ ${stage} -le 0 ]; then echo "Stage 0: Data Downloading" @@ -221,24 +226,34 @@ if [ ${stage} -le 8 ]; then log_file=$dir/log/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" opts="" - if $apply_specaug; then - opts="$opts --max-epoch 95 --lr-scheduler tri_stage --warmup-steps $((2000/ngpus)) --hold-steps $((600000/ngpus)) --decay-steps $((1040000/ngpus))" - opts="$opts --encoder-rnn-layers 5" - specaug_config="{'W': 80, 'F': 27, 'T': 100, 'num_freq_masks': 2, 'num_time_masks': 2, 'p': 1.0}" + if $use_transformer; then + update_freq=$(((8+ngpus-1)/ngpus)) + opts="$opts --arch speech_transformer_librispeech --max-tokens 22000 --max-epoch 100 --lr-scheduler tri_stage" + opts="$opts --warmup-steps $((25000/ngpus/update_freq)) --hold-steps $((900000/ngpus/update_freq)) --decay-steps $((1550000/ngpus/update_freq))" + if $apply_specaug; then + specaug_config="{'W': 80, 'F': 27, 'T': 100, 'num_freq_masks': 2, 'num_time_masks': 2, 'p': 1.0}" + fi else - opts="$opts --max-epoch 30 --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --start-reduce-lr-epoch 10" + update_freq=$(((2+ngpus-1)/ngpus)) + opts="$opts --arch speech_conv_lstm_librispeech" + if $apply_specaug; then + opts="$opts --max-epoch 95 --lr-scheduler tri_stage" + opts="$opts --warmup-steps $((2000/ngpus/update_freq)) --hold-steps $((600000/ngpus/update_freq)) --decay-steps $((1040000/ngpus/update_freq))" + opts="$opts --encoder-rnn-layers 5" + specaug_config="{'W': 80, 'F': 27, 'T': 100, 'num_freq_masks': 2, 'num_time_masks': 2, 'p': 1.0}" + else + opts="$opts --max-epoch 30 --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --start-reduce-lr-epoch 10" + fi fi CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data --task speech_recognition_espresso --seed 1 --user-dir espresso \ - --log-interval $((8000/ngpus)) --log-format simple --print-training-sample-interval $((4000/ngpus)) \ - --num-workers 0 --data-buffer-size 0 --max-tokens 26000 --max-sentences 24 --curriculum 1 \ - --valid-subset $valid_subset --max-sentences-valid 48 --ddp-backend no_c10d \ + --log-interval $((8000/ngpus/update_freq)) --log-format simple --print-training-sample-interval $((4000/ngpus/update_freq)) \ + --num-workers 0 --data-buffer-size 0 --max-tokens 26000 --max-sentences 24 --curriculum 1 --empty-cache-freq 50 \ + --valid-subset $valid_subset --max-sentences-valid 48 --ddp-backend no_c10d --update-freq $update_freq \ --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ --optimizer adam --lr 0.001 --weight-decay 0.0 --clip-norm 2.0 \ - --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates $((6000/ngpus)) \ + --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates $((6000/ngpus/update_freq)) \ --keep-interval-updates 3 --keep-last-epochs 5 --validate-interval 1 --best-checkpoint-metric wer \ - --arch speech_conv_lstm_librispeech --criterion label_smoothed_cross_entropy_v2 \ - --label-smoothing 0.1 --smoothing-type uniform \ - --scheduled-sampling-probs 1.0 --start-scheduled-sampling-epoch 1 \ + --criterion label_smoothed_cross_entropy_v2 --label-smoothing 0.1 --smoothing-type uniform \ --dict $dict --bpe sentencepiece --sentencepiece-vocab ${sentencepiece_model}.model \ --max-source-positions 9999 --max-target-positions 999 \ $opts 
--specaugment-config "$specaug_config" 2>&1 | tee $log_file diff --git a/examples/asr_swbd/run.sh b/examples/asr_swbd/run.sh index e477abac4..cc748b57a 100755 --- a/examples/asr_swbd/run.sh +++ b/examples/asr_swbd/run.sh @@ -16,6 +16,7 @@ train_set=train_nodup valid_set=train_dev test_set="train_dev eval2000 rt03" checkpoint=checkpoint_best.pt +use_transformer=false # LM related lm_affix= @@ -48,7 +49,11 @@ apply_specaug=false . ./utils/parse_options.sh lmdir=exp/lm_lstm${lm_affix:+_${lm_affix}} -dir=exp/lstm${affix:+_$affix} +if $use_transformer; then + dir=exp/transformer${affix:+_$affix} +else + dir=exp/lstm${affix:+_$affix} +fi if [ $stage -le 0 ]; then echo "Stage 0: Data Preparation" @@ -261,25 +266,33 @@ if [ $stage -le 7 ]; then mkdir -p $dir/log log_file=$dir/log/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" - if $apply_specaug; then - opts="$opts --max-epoch 100 --lr-scheduler tri_stage --warmup-steps $((1000/ngpus)) --hold-steps $((140000/ngpus)) --decay-steps $((330000/ngpus))" - opts="$opts --encoder-rnn-hidden-size 1024 --encoder-rnn-layers 5 --decoder-embed-dim 512 --decoder-hidden-size 1024" - opts="$opts --decoder-out-embed-dim 3072 --attention-dim 512 --dropout 0.4" - specaug_config="{'W': 40, 'F': 18, 'T': 70, 'num_freq_masks': 2, 'num_time_masks': 2, 'p': 0.2}" + update_freq=$(((2+ngpus-1)/ngpus)) + if $use_transformer; then + opts="$opts --arch speech_transformer_swbd --max-epoch 100 --lr-scheduler tri_stage" + opts="$opts --warmup-steps $((25000/ngpus/update_freq)) --hold-steps $((180000/ngpus/update_freq)) --decay-steps $((320000/ngpus/update_freq))" + if $apply_specaug; then + specaug_config="{'W': 40, 'F': 18, 'T': 70, 'num_freq_masks': 2, 'num_time_masks': 2, 'p': 0.2}" + fi else - opts="$opts --max-epoch 35 --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --start-reduce-lr-epoch 14" + opts="$opts --arch speech_conv_lstm_swbd --scheduled-sampling-probs 0.9,0.8,0.7,0.6 --start-scheduled-sampling-epoch 6" + if $apply_specaug; then + opts="$opts --max-epoch 100 --lr-scheduler tri_stage --warmup-steps $((1000/ngpus/update_freq)) --hold-steps $((180000/ngpus/update_freq)) --decay-steps $((360000/ngpus/update_freq))" + opts="$opts --encoder-rnn-hidden-size 1024 --encoder-rnn-layers 5 --decoder-embed-dim 512 --decoder-hidden-size 1024" + opts="$opts --decoder-out-embed-dim 3072 --attention-dim 512 --dropout 0.4" + specaug_config="{'W': 40, 'F': 18, 'T': 70, 'num_freq_masks': 2, 'num_time_masks': 2, 'p': 0.2}" + else + opts="$opts --max-epoch 35 --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --start-reduce-lr-epoch 14" + fi fi CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data --task speech_recognition_espresso --seed 1 --user-dir espresso \ - --log-interval $((3000/ngpus)) --log-format simple --print-training-sample-interval $((4000/ngpus)) \ - --num-workers 0 --data-buffer-size 0 --max-tokens 26000 --max-sentences 48 --curriculum 2 \ - --valid-subset $valid_subset --max-sentences-valid 64 --ddp-backend no_c10d \ + --log-interval $((3000/ngpus/update_freq)) --log-format simple --print-training-sample-interval $((4000/ngpus/update_freq)) \ + --num-workers 0 --data-buffer-size 0 --max-tokens 26000 --max-sentences 48 --curriculum 2 --empty-cache-freq 50 \ + --valid-subset $valid_subset --max-sentences-valid 64 --ddp-backend no_c10d --update-freq $update_freq \ --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ --optimizer adam --lr 0.001 --weight-decay 0.0 --clip-norm 2.0 \ - 
--save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates $((3000/ngpus)) \ + --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates $((3000/ngpus/update_freq)) \ --keep-interval-updates 3 --keep-last-epochs 5 --validate-interval 1 --best-checkpoint-metric wer \ - --arch speech_conv_lstm_swbd --criterion label_smoothed_cross_entropy_v2 \ - --label-smoothing 0.1 --smoothing-type uniform \ - --scheduled-sampling-probs 0.9,0.8,0.7,0.6 --start-scheduled-sampling-epoch 6 \ + --criterion label_smoothed_cross_entropy_v2 --label-smoothing 0.1 --smoothing-type uniform \ --dict $dict --bpe sentencepiece --sentencepiece-vocab ${sentencepiece_model}.model --non-lang-syms $nlsyms \ --max-source-positions 9999 --max-target-positions 999 \ $opts --specaugment-config "$specaug_config" 2>&1 | tee $log_file diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index 9d7931aba..4171d078c 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -16,6 +16,7 @@ train_set=train_si284 valid_set=test_dev93 test_set=test_eval92 checkpoint=checkpoint_best.pt +use_transformer=false # LM related lm_affix= @@ -45,7 +46,11 @@ do_delta=false lmdir=exp/lm_lstm${lm_affix:+_${lm_affix}} wordlmdir=exp/wordlm_lstm${wordlm_affix:+_${wordlm_affix}} -dir=exp/lstm${affix:+_$affix} +if $use_transformer; then + dir=exp/transformer${affix:+_$affix} +else + dir=exp/lstm${affix:+_$affix} +fi if [ ${stage} -le 0 ]; then echo "Stage 0: Data Preparation" @@ -269,18 +274,24 @@ if [ ${stage} -le 9 ]; then mkdir -p $dir/log log_file=$dir/log/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" + update_freq=$(((2+ngpus-1)/ngpus)) + if $use_transformer; then + opts="$opts --arch speech_transformer_wsj --max-epoch 100 --lr-scheduler tri_stage" + opts="$opts --warmup-steps $((25000/ngpus/update_freq)) --hold-steps $((60000/ngpus/update_freq)) --decay-steps $((100000/ngpus/update_freq))" + else + opts="$opts --arch speech_conv_lstm_wsj --max-epoch 35 --lr-scheduler reduce_lr_on_plateau_v2" + opts="$opts --lr-shrink 0.5 --start-reduce-lr-epoch 11" + opts="$opts --scheduled-sampling-probs 0.5 --start-scheduled-sampling-epoch 6" + fi CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data --task speech_recognition_espresso --seed 1 --user-dir espresso \ - --log-interval $((800/ngpus)) --log-format simple --print-training-sample-interval $((2000/ngpus)) \ - --num-workers 0 --data-buffer-size 0 --max-tokens 24000 --max-sentences 32 --curriculum 2 \ - --valid-subset $valid_subset --max-sentences-valid 64 --ddp-backend no_c10d \ + --log-interval $((800/ngpus/update_freq)) --log-format simple --print-training-sample-interval $((2000/ngpus/update_freq)) \ + --num-workers 0 --data-buffer-size 0 --max-tokens 24000 --max-sentences 32 --curriculum 2 --empty-cache-freq 50 \ + --valid-subset $valid_subset --max-sentences-valid 64 --ddp-backend no_c10d --update-freq $update_freq \ --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ - --max-epoch 35 --optimizer adam --lr 0.001 --weight-decay 0.0 \ - --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --start-reduce-lr-epoch 11 \ - --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates $((800/ngpus)) \ + --optimizer adam --lr 0.001 --weight-decay 0.0 \ + --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates $((800/ngpus/update_freq)) \ --keep-interval-updates 5 --keep-last-epochs 5 --validate-interval 1 --best-checkpoint-metric wer \ - --arch 
speech_conv_lstm_wsj --criterion label_smoothed_cross_entropy_v2 \ - --label-smoothing 0.05 --smoothing-type temporal \ - --scheduled-sampling-probs 0.5 --start-scheduled-sampling-epoch 6 \ + --criterion label_smoothed_cross_entropy_v2 --label-smoothing 0.05 --smoothing-type temporal \ --dict $dict --bpe characters_asr --non-lang-syms $nlsyms \ --max-source-positions 9999 --max-target-positions 999 $opts 2>&1 | tee $log_file fi From 8600f76c45f1a41246eb3c9ec1c92e0946d53eac Mon Sep 17 00:00:00 2001 From: freewym Date: Tue, 7 Jul 2020 21:25:29 -0400 Subject: [PATCH 088/119] ignore flake8's FileNotFoundError for soft links to kaldi files; code adaptation/changes according to the commits on Jul 8, 2020 --- .github/workflows/build.yml | 2 +- espresso/data/asr_dataset.py | 2 +- espresso/data/asr_xent_dataset.py | 4 +-- espresso/dump_posteriors.py | 9 ++++++- espresso/models/speech_lstm.py | 24 ++++++++--------- espresso/models/speech_lstm_encoder_model.py | 4 +-- espresso/models/speech_tdnn.py | 13 ++++----- espresso/models/speech_transformer.py | 11 ++++---- .../speech_transformer_encoder_model.py | 9 ++++--- espresso/speech_recognize.py | 12 ++++++--- espresso/speech_train.py | 27 ++++++++++++++++--- .../tools/generate_log_probs_for_decoding.py | 8 ++---- espresso/tools/simple_greedy_decoder.py | 9 +++---- tests/espresso/test_asr_dataset.py | 2 -- 14 files changed, 79 insertions(+), 57 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 29e5254d3..85b95a607 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -45,7 +45,7 @@ jobs: run: | pip install flake8 # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --extend-exclude fairseq/model_parallel/megatron + flake8 . --count --select=E9,F63,F7,F82 --ignore=E902 --show-source --statistics --extend-exclude fairseq/model_parallel/megatron # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --extend-exclude fairseq/model_parallel/megatron diff --git a/espresso/data/asr_dataset.py b/espresso/data/asr_dataset.py index 549b2611e..eb4afc28e 100644 --- a/espresso/data/asr_dataset.py +++ b/espresso/data/asr_dataset.py @@ -193,7 +193,7 @@ def _match_src_tgt(self): def get_batch_shapes(self): return self.buckets - + def __getitem__(self, index): tgt_item = self.tgt[index][0] if self.tgt is not None else None raw_text_item = self.tgt[index][1] if self.tgt is not None else None diff --git a/espresso/data/asr_xent_dataset.py b/espresso/data/asr_xent_dataset.py index 3b3057484..2135983c3 100644 --- a/espresso/data/asr_xent_dataset.py +++ b/espresso/data/asr_xent_dataset.py @@ -415,7 +415,7 @@ def __init__( ) self.tgt_sizes = self.tgt.sizes logger.info("bucketing target lengths: {}".format(list(self.tgt.buckets))) - + # determine bucket sizes using self.num_tokens, which will return # the padded lengths (thanks to FeatBucketPadLengthDataset) num_tokens = np.vectorize(self.num_tokens, otypes=[np.long]) @@ -476,7 +476,7 @@ def _match_src_text(self): return True def get_batch_shapes(self): - return self.buckets + return self.buckets def __getitem__(self, index): tgt_item = self.tgt[index] if self.tgt is not None else None diff --git a/espresso/dump_posteriors.py b/espresso/dump_posteriors.py index 9c609488f..5b50e2cbd 100755 --- a/espresso/dump_posteriors.py +++ b/espresso/dump_posteriors.py @@ -11,6 +11,8 @@ import logging import sys +import numpy as np + import torch from fairseq import checkpoint_utils, options, tasks, utils @@ -45,6 +47,11 @@ def _main(args, output_file): args.max_tokens = 12000 logger.info(args) + # Fix seed for stochastic decoding + if args.seed is not None and not args.no_seed_provided: + np.random.seed(args.seed) + torch.manual_seed(args.seed) + use_cuda = torch.cuda.is_available() and not args.cpu # Load dataset split @@ -68,7 +75,7 @@ def _main(args, output_file): # Optimize ensemble for generation for model in models: - model.make_generation_fast_() + model.prepare_for_inference_(args) if args.fp16: model.half() if use_cuda: diff --git a/espresso/models/speech_lstm.py b/espresso/models/speech_lstm.py index 7d9fec93d..670d2663e 100644 --- a/espresso/models/speech_lstm.py +++ b/espresso/models/speech_lstm.py @@ -27,7 +27,7 @@ LSTMCell, Linear, ) -from fairseq.modules import AdaptiveSoftmax +from fairseq.modules import AdaptiveSoftmax, FairseqDropout from espresso.modules import speech_attention from espresso.tools.scheduled_sampling_rate_scheduler import ScheduledSamplingRateScheduler @@ -152,7 +152,7 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): pretrained_decoder_embed = load_pretrained_embedding_from_file( args.decoder_embed_path, task.target_dictionary, - args.decoder_embed_dim + args.decoder_embed_dim, ) # one last double check of parameter combinations if args.share_decoder_input_output_embed and ( @@ -335,8 +335,8 @@ def __init__( super().__init__(None) # no src dictionary self.conv_layers_before = conv_layers_before self.num_layers = num_layers - self.dropout_in = dropout_in - self.dropout_out = dropout_out + self.dropout_in_module = FairseqDropout(dropout_in, module_name=self.__class__.__name__) + self.dropout_out_module = FairseqDropout(dropout_out, module_name=self.__class__.__name__) self.bidirectional = bidirectional self.hidden_size = hidden_size self.residual = residual @@ -397,7 +397,7 @@ def forward( bsz, seqlen = x.size(0), x.size(1) - x = 
F.dropout(x, p=self.dropout_in, training=self.training) + x = self.dropout_in_module(x) # B x T x C -> T x B x C x = x.transpose(0, 1) @@ -424,7 +424,7 @@ def forward( # unpack outputs and apply dropout x, _ = nn.utils.rnn.pad_packed_sequence(packed_outs, padding_value=self.padding_value*1.0) if i < len(self.lstm) - 1: # not applying dropout for the last layer - x = F.dropout(x, p=self.dropout_out, training=self.training) + x = self.dropout_out_module(x) x = x + prev_x if self.residual and i > 0 else x assert list(x.size()) == [seqlen, bsz, self.output_units] @@ -477,8 +477,8 @@ def __init__( scheduled_sampling_rate_scheduler=None, ): super().__init__(dictionary) - self.dropout_in = dropout_in - self.dropout_out = dropout_out + self.dropout_in_module = FairseqDropout(dropout_in, module_name=self.__class__.__name__) + self.dropout_out_module = FairseqDropout(dropout_out, module_name=self.__class__.__name__) self.hidden_size = hidden_size self.share_input_output_embed = share_input_output_embed if attn_type is None or attn_type.lower() == "none": @@ -527,7 +527,7 @@ def __init__( if adaptive_softmax_cutoff is not None: # setting adaptive_softmax dropout to dropout_out for now but can be redefined self.adaptive_softmax = AdaptiveSoftmax( - num_embeddings, hidden_size, adaptive_softmax_cutoff, dropout=dropout_out + num_embeddings, hidden_size, adaptive_softmax_cutoff, dropout=dropout_out, ) elif not self.share_input_output_embed: self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out) @@ -629,7 +629,7 @@ def extract_features( # embed tokens x = self.embed_tokens(prev_output_tokens) - x = F.dropout(x, p=self.dropout_in, training=self.training) + x = self.dropout_in_module(x) # B x T x C -> T x B x C x = x.transpose(0, 1) @@ -672,7 +672,7 @@ def extract_features( input = torch.cat((hidden, context), dim=1) else: input = hidden - input = F.dropout(input, p=self.dropout_out, training=self.training) + input = self.dropout_out_module(input) if self.residual and i > 0: if encoder_out is not None: hidden_sum = input[:, :hidden.size(1)] + prev_layer_hidden @@ -713,7 +713,7 @@ def extract_features( if hasattr(self, "additional_fc") and self.adaptive_softmax is None: x = self.additional_fc(x) - x = F.dropout(x, p=self.dropout_out, training=self.training) + x = self.dropout_out_module(x) # srclen x tgtlen x bsz -> bsz x tgtlen x srclen if not self.training and encoder_out is not None and self.need_attn: assert attn_scores is not None diff --git a/espresso/models/speech_lstm_encoder_model.py b/espresso/models/speech_lstm_encoder_model.py index 5c661aa24..4940993f8 100644 --- a/espresso/models/speech_lstm_encoder_model.py +++ b/espresso/models/speech_lstm_encoder_model.py @@ -160,7 +160,7 @@ def __init__( ): super().__init__( conv_layers_before=conv_layers_before, input_size=input_size, hidden_size=hidden_size, - num_layers=num_layers, dropout_in=dropout_in, dropout_out=dropout_in, + num_layers=num_layers, dropout_in=dropout_in, dropout_out=dropout_out, bidirectional=bidirectional, residual=residual, left_pad=left_pad, padding_value=padding_value, src_bucketed=src_bucketed, max_source_positions=max_source_positions, ) @@ -175,7 +175,7 @@ def __init__( self.training_stage = training_stage # only for encoder-only model - self.fc_out = Linear(self.output_units, num_targets, dropout=dropout_out) \ + self.fc_out = Linear(self.output_units, num_targets, dropout=self.dropout_out_module.p) \ if num_targets is not None else None def forward( diff --git a/espresso/models/speech_tdnn.py 
b/espresso/models/speech_tdnn.py index de470ca05..2ceee1423 100644 --- a/espresso/models/speech_tdnn.py +++ b/espresso/models/speech_tdnn.py @@ -20,6 +20,7 @@ ) from fairseq.models.fairseq_encoder import EncoderOut from fairseq.models.lstm import Linear +from fairseq.modules import FairseqDropout import espresso.tools.utils as speech_utils @@ -191,8 +192,8 @@ def __init__( dilations = [dilations] * num_layers else: assert len(dilations) == num_layers - self.dropout_in = dropout_in - self.dropout_out = dropout_out + self.dropout_in_module = FairseqDropout(dropout_in, module_name=self.__class__.__name__) + self.dropout_out_module = FairseqDropout(dropout_out, module_name=self.__class__.__name__) self.residual = residual self.tdnn = nn.ModuleList([ @@ -204,7 +205,7 @@ def __init__( for layer in range(num_layers) ]) - receptive_field_radius = sum(l.padding for l in self.tdnn) + receptive_field_radius = sum(layer.padding for layer in self.tdnn) assert chunk_width is None or (chunk_width > 0 and chunk_left_context >= receptive_field_radius) if ( chunk_width is not None and chunk_width > 0 @@ -216,7 +217,7 @@ def __init__( if chunk_width is not None else None self.training_stage = training_stage - self.fc_out = Linear(hidden_sizes[-1], output_size, dropout=dropout_out) + self.fc_out = Linear(hidden_sizes[-1], output_size, dropout=self.dropout_out_module.p) def output_lengths(self, in_lengths): out_lengths = in_lengths @@ -248,14 +249,14 @@ def forward(self, src_tokens, src_lengths: Tensor, **unused): def extract_features(self, src_tokens, src_lengths, **unused): x, x_lengths = src_tokens, src_lengths - x = F.dropout(x, p=self.dropout_in, training=self.training) + x = self.dropout_in_module(x) for i in range(len(self.tdnn)): if self.residual and i > 0: # residual connection starts from the 2nd layer prev_x = x # apply Tdnn x, x_lengths, padding_mask = self.tdnn[i](x, x_lengths) - x = F.dropout(x, p=self.dropout_out, training=self.training) + x = self.dropout_out_module(x) x = x + prev_x if self.residual and i > 0 and x.size(1) == prev_x.size(1) else x x = x.transpose(0, 1) # B x T x C -> T x B x C diff --git a/espresso/models/speech_transformer.py b/espresso/models/speech_transformer.py index de2875b5b..b8d556e2a 100644 --- a/espresso/models/speech_transformer.py +++ b/espresso/models/speech_transformer.py @@ -4,12 +4,11 @@ # LICENSE file in the root directory of this source tree. 
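The recurring edit in speech_lstm.py and speech_tdnn.py above replaces bare F.dropout(x, p=..., training=self.training) calls with FairseqDropout modules, which keep the probability as module state (exposed as .p, e.g. for the fc_out layers above) and handle the train/eval switch themselves. A minimal sketch of the pattern, assuming only fairseq's FairseqDropout as imported in these diffs; TinyEncoder is a toy module for illustration.

import torch
import torch.nn as nn
from fairseq.modules import FairseqDropout

class TinyEncoder(nn.Module):
    def __init__(self, dim=8, dropout_in=0.1, dropout_out=0.2):
        super().__init__()
        self.dropout_in_module = FairseqDropout(dropout_in, module_name=self.__class__.__name__)
        self.dropout_out_module = FairseqDropout(dropout_out, module_name=self.__class__.__name__)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        x = self.dropout_in_module(x)  # replaces F.dropout(x, p=self.dropout_in, training=self.training)
        x = self.proj(x)
        return self.dropout_out_module(x)

enc = TinyEncoder()
out = enc(torch.randn(4, 8))     # dropout active while enc.training is True
print(enc.dropout_out_module.p)  # probability stays accessible, cf. the fc_out dropout= argument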
import logging -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional import torch from torch import Tensor import torch.nn as nn -import torch.nn.functional as F from fairseq import options from fairseq.models import ( @@ -18,13 +17,13 @@ ) from fairseq.models.fairseq_encoder import EncoderOut from fairseq.models.transformer import ( - Embedding, Linear, TransformerModel, TransformerEncoder, TransformerDecoder, ) from fairseq.modules import ( + FairseqDropout, LayerDropModuleList, LayerNorm, PositionalEmbedding, @@ -247,7 +246,7 @@ def __init__(self, args, conv_layers_before=None, input_size=83, transformer_con super(TransformerEncoder, self).__init__(None) # no src dictionary self.register_buffer("version", torch.Tensor([3])) - self.dropout = args.dropout + self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__) self.encoder_layerdrop = args.encoder_layerdrop embed_dim = args.encoder_embed_dim @@ -360,7 +359,7 @@ def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False): x, encoder_padding_mask = src_tokens, \ ~speech_utils.sequence_mask(src_lengths, src_tokens.size(1)) - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) if self.fc0 is not None: x = self.fc0(x) if self.embed_positions is not None: @@ -368,7 +367,7 @@ def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False): x = x + self.embed_positions((~encoder_padding_mask).int()) if self.layernorm_embedding is not None: x = self.layernorm_embedding(x) - x = F.dropout(x, p=self.dropout, training=self.training) + x = self.dropout_module(x) elif self.embed_positions is not None: # 0s in `~encoder_padding_mask` are used as pad_idx for positional embeddings x = x + self.embed_positions((~encoder_padding_mask).int()) diff --git a/espresso/models/speech_transformer_encoder_model.py b/espresso/models/speech_transformer_encoder_model.py index b4c49edf6..124208fa2 100644 --- a/espresso/models/speech_transformer_encoder_model.py +++ b/espresso/models/speech_transformer_encoder_model.py @@ -87,10 +87,10 @@ def build_model(cls, args, task): # make sure that all args are properly defaulted (in case there are any new ones) base_architecture(args) - + if args.encoder_layers_to_keep: args.encoder_layers = len(args.encoder_layers_to_keep.split(",")) - + if getattr(args, "max_source_positions", None) is None: args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS @@ -148,7 +148,8 @@ def set_num_updates(self, num_updates): super().set_num_updates(num_updates) @classmethod - def build_encoder(cls, args, conv_layers_before=None, input_size=83, transformer_context=None, + def build_encoder( + cls, args, conv_layers_before=None, input_size=83, transformer_context=None, num_targets=None, chunk_width=None, chunk_left_context=0, training_stage=True, ): return SpeechChunkTransformerEncoder( @@ -217,7 +218,7 @@ def __init__( self.training_stage = training_stage # only for encoder-only model - self.fc_out = Linear(args.encoder_embed_dim, num_targets, dropout=self.dropout) \ + self.fc_out = Linear(args.encoder_embed_dim, num_targets, dropout=self.dropout_module.p) \ if num_targets is not None else None def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False): diff --git a/espresso/speech_recognize.py b/espresso/speech_recognize.py index 892fcf5c0..405db678b 100755 --- a/espresso/speech_recognize.py +++ b/espresso/speech_recognize.py @@ -13,6 +13,8 @@ import os import sys +import numpy as np + import torch 
from fairseq import checkpoint_utils, options, tasks, utils @@ -59,6 +61,11 @@ def _main(args, output_file): args.max_tokens = 12000 logger.info(args) + # Fix seed for stochastic decoding + if args.seed is not None and not args.no_seed_provided: + np.random.seed(args.seed) + torch.manual_seed(args.seed) + use_cuda = torch.cuda.is_available() and not args.cpu # Load dataset split @@ -103,10 +110,7 @@ def _main(args, output_file): # Optimize ensemble for generation for model in models: - model.make_generation_fast_( - beamable_mm_beam_size=None if args.no_beamable_mm else args.beam, - need_attn=args.print_alignment, - ) + model.prepare_for_inference_(args) if args.fp16: model.half() if use_cuda: diff --git a/espresso/speech_train.py b/espresso/speech_train.py index 7bdd3d3af..a9113255d 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -8,11 +8,13 @@ Train a new model on one or across multiple GPUs. """ +import argparse import logging import math import os import random import sys +from typing import Callable, Optional import numpy as np import torch @@ -39,7 +41,13 @@ logger = logging.getLogger("espresso.speech_rain") -def main(args, init_distributed=False): +def main( + args, + init_distributed=False, + after_distributed_init_fn: Optional[ + Callable[[argparse.Namespace], argparse.Namespace] + ] = None, +): utils.import_user_module(args) assert ( @@ -54,6 +62,8 @@ def main(args, init_distributed=False): utils.set_torch_seed(args.seed) if init_distributed: args.distributed_rank = distributed_utils.distributed_init(args) + if after_distributed_init_fn: + args = after_distributed_init_fn(args) if distributed_utils.is_master(args): checkpoint_utils.verify_checkpoint_directory(args.save_dir) @@ -234,7 +244,7 @@ def train(args, trainer, task, epoch_itr): # update the state prior stored in the model for cross-entropy training if hasattr(task, "update_state_prior"): task.update_state_prior(trainer.get_model()) - + end_of_epoch = not itr.has_next() valid_losses, should_stop = validate_and_save( args, trainer, task, epoch_itr, valid_subsets, end_of_epoch @@ -337,11 +347,20 @@ def get_valid_stats(args, trainer, stats): return stats -def distributed_main(i, args, start_rank=0): +def distributed_main( + i, + args, + start_rank=0, + after_distributed_init_fn: Optional[ + Callable[[argparse.Namespace], argparse.Namespace] + ] = None, +): args.device_id = i if args.distributed_rank is None: # torch.multiprocessing.spawn args.distributed_rank = start_rank + i - main(args, init_distributed=True) + main( + args, init_distributed=True, after_distributed_init_fn=after_distributed_init_fn + ) def print_options_meaning_changes(args): diff --git a/espresso/tools/generate_log_probs_for_decoding.py b/espresso/tools/generate_log_probs_for_decoding.py index fd6022aa2..c9eedebd0 100644 --- a/espresso/tools/generate_log_probs_for_decoding.py +++ b/espresso/tools/generate_log_probs_for_decoding.py @@ -12,15 +12,13 @@ class GenerateLogProbsForDecoding(nn.Module): - def __init__(self, models, retain_dropout=False, apply_log_softmax=False): + def __init__(self, models, apply_log_softmax=False): """Generate the neural network's output intepreted as log probabilities for decoding with Kaldi. 
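Both decoding entry points above (dump_posteriors.py and speech_recognize.py) converge on the same inference-time setup: seed the RNGs when a seed is provided so stochastic decoding is reproducible, then call prepare_for_inference_() on each model instead of the older make_generation_fast_(). A condensed sketch of that flow, reusing the fairseq calls that appear in these diffs; the helper name setup_models_for_decoding is assumed, and the args fields are the ones used there.

import numpy as np
import torch
from fairseq import utils

def setup_models_for_decoding(models, args):
    use_cuda = torch.cuda.is_available() and not args.cpu
    # fix seeds for stochastic decoding (mirrors dump_posteriors.py / speech_recognize.py)
    if args.seed is not None and not getattr(args, "no_seed_provided", False):
        np.random.seed(args.seed)
        utils.set_torch_seed(args.seed)
    for model in models:
        model.prepare_for_inference_(args)  # replaces make_generation_fast_(...)
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()
    return models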
Args: models (List[~fairseq.models.FairseqModel]): ensemble of models, currently support fairseq.models.TransformerModel for scripting - retain_dropout (bool, optional): use dropout when generating - (default: False) apply_log_softmax (bool, optional): apply log-softmax on top of the network's output (default: False) """ @@ -30,11 +28,9 @@ def __init__(self, models, retain_dropout=False, apply_log_softmax=False): self.model = models else: self.model = EnsembleModel(models) - self.retain_dropout = retain_dropout self.apply_log_softmax = apply_log_softmax - if not self.retain_dropout: - self.model.eval() + self.model.eval() def cuda(self): self.model.cuda() diff --git a/espresso/tools/simple_greedy_decoder.py b/espresso/tools/simple_greedy_decoder.py index 5cdb9815b..c210cd569 100644 --- a/espresso/tools/simple_greedy_decoder.py +++ b/espresso/tools/simple_greedy_decoder.py @@ -14,7 +14,7 @@ class SimpleGreedyDecoder(nn.Module): def __init__( - self, models, dictionary, max_len_a=0, max_len_b=200, retain_dropout=False, + self, models, dictionary, max_len_a=0, max_len_b=200, temperature=1.0, for_validation=True, ): """Decode given speech audios with the simple greedy search. @@ -25,8 +25,6 @@ def __init__( dictionary (~fairseq.data.Dictionary): dictionary max_len_a/b (int, optional): generate sequences of maximum length ax + b, where x is the source length - retain_dropout (bool, optional): use dropout when generating - (default: False) temperature (float, optional): temperature, where values >1.0 produce more uniform samples and values <1.0 produce sharper samples (default: 1.0) @@ -47,11 +45,10 @@ def __init__( self.vocab_size = len(dictionary) self.max_len_a = max_len_a self.max_len_b = max_len_b - self.retain_dropout = retain_dropout self.temperature = temperature assert temperature > 0, "--temperature must be greater than 0" - if not self.retain_dropout: - self.model.eval() + + self.model.eval() self.for_validation = for_validation def cuda(self): diff --git a/tests/espresso/test_asr_dataset.py b/tests/espresso/test_asr_dataset.py index 2101253df..1b01836a1 100644 --- a/tests/espresso/test_asr_dataset.py +++ b/tests/espresso/test_asr_dataset.py @@ -127,8 +127,6 @@ def _asr_dataset_helper( tgt_dataset, tgt_dataset.sizes, self.dictionary, left_pad_source=False, left_pad_target=False, - max_source_positions=1000, - max_target_positions=200, ) # assume one is a subset of the other From bc6901beab19e2e71bd2a900297a159136d6760d Mon Sep 17 00:00:00 2001 From: freewym Date: Wed, 15 Jul 2020 16:38:39 -0400 Subject: [PATCH 089/119] code adaptation/changes according to the commits on Jul 14, 2020 --- espresso/data/asr_dataset.py | 6 ++++-- espresso/dump_posteriors.py | 2 +- espresso/models/speech_transformer.py | 9 +++++---- espresso/speech_recognize.py | 2 +- espresso/speech_train.py | 8 +++++++- espresso/tasks/speech_recognition.py | 4 +++- espresso/tasks/speech_recognition_hybrid.py | 4 ++++ examples/asr_librispeech/run.sh | 4 ++-- examples/asr_swbd/run.sh | 4 ++-- 9 files changed, 29 insertions(+), 14 deletions(-) diff --git a/espresso/data/asr_dataset.py b/espresso/data/asr_dataset.py index eb4afc28e..dba413de0 100644 --- a/espresso/data/asr_dataset.py +++ b/espresso/data/asr_dataset.py @@ -32,7 +32,7 @@ def merge(key, left_pad, move_eos_to_beginning=False): return speech_utils.collate_frames( [s[key] for s in samples], 0.0, left_pad, ) - elif key == 'target': + elif key == 'target' or key == 'prev_output_tokens': return data_utils.collate_tokens( [s[key] for s in samples], pad_idx, eos_idx, 
left_pad, move_eos_to_beginning, @@ -61,7 +61,9 @@ def merge(key, left_pad, move_eos_to_beginning=False): target = target.index_select(0, sort_order) ntokens = sum(s['target'].ne(pad_idx).int().sum().item() for s in samples) - if input_feeding: + if samples[0].get('prev_output_tokens', None) is not None: + prev_output_tokens = merge('prev_output_tokens', left_pad=left_pad_target) + elif input_feeding: # we create a shifted version of targets for feeding the # previous output token(s) into the next decoder step prev_output_tokens = merge( diff --git a/espresso/dump_posteriors.py b/espresso/dump_posteriors.py index 5b50e2cbd..7020beace 100755 --- a/espresso/dump_posteriors.py +++ b/espresso/dump_posteriors.py @@ -50,7 +50,7 @@ def _main(args, output_file): # Fix seed for stochastic decoding if args.seed is not None and not args.no_seed_provided: np.random.seed(args.seed) - torch.manual_seed(args.seed) + utils.set_torch_seed(args.seed) use_cuda = torch.cuda.is_available() and not args.cpu diff --git a/espresso/models/speech_transformer.py b/espresso/models/speech_transformer.py index b8d556e2a..804f52f8f 100644 --- a/espresso/models/speech_transformer.py +++ b/espresso/models/speech_transformer.py @@ -268,6 +268,11 @@ def __init__(self, args, conv_layers_before=None, input_size=83, transformer_con else None ) + if getattr(args, "layernorm_embedding", False): + self.layernorm_embedding = LayerNorm(embed_dim) + else: + self.layernorm_embedding = None + if not args.adaptive_input and args.quant_noise_pq > 0: self.quant_noise = apply_quant_noise_( nn.Linear(embed_dim, embed_dim, bias=False), @@ -290,10 +295,6 @@ def __init__(self, args, conv_layers_before=None, input_size=83, transformer_con self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None - if getattr(args, "layernorm_embedding", False): - self.layernorm_embedding = LayerNorm(embed_dim) - else: - self.layernorm_embedding = None self.transformer_context = transformer_context diff --git a/espresso/speech_recognize.py b/espresso/speech_recognize.py index 405db678b..6a2f52561 100755 --- a/espresso/speech_recognize.py +++ b/espresso/speech_recognize.py @@ -64,7 +64,7 @@ def _main(args, output_file): # Fix seed for stochastic decoding if args.seed is not None and not args.no_seed_provided: np.random.seed(args.seed) - torch.manual_seed(args.seed) + utils.set_torch_seed(args.seed) use_cuda = torch.cuda.is_available() and not args.cpu diff --git a/espresso/speech_train.py b/espresso/speech_train.py index a9113255d..87d7ec178 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -38,7 +38,7 @@ level=logging.INFO, stream=sys.stdout, ) -logger = logging.getLogger("espresso.speech_rain") +logger = logging.getLogger("espresso.speech_train") def main( @@ -194,6 +194,8 @@ def tpu_data_loader(args, itr): @metrics.aggregate("train") def train(args, trainer, task, epoch_itr): """Train the model for one epoch and return validation losses.""" + logger.info("begin training epoch {}".format(epoch_itr.epoch)) + # Initialize data iterator itr = epoch_itr.next_epoch_itr( fix_batches_to_gpus=args.fix_batches_to_gpus, @@ -253,6 +255,7 @@ def train(args, trainer, task, epoch_itr): break # log end-of-epoch stats + logger.info("end of epoch {} (average epoch stats below)".format(epoch_itr.epoch)) stats = get_training_stats(metrics.get_smoothed_values("train")) progress.print(stats, tag="train", step=num_updates) @@ -287,6 +290,7 @@ def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoc # Save checkpoint if do_save 
or should_stop: + logger.info("begin save checkpoint") checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) return valid_losses, should_stop @@ -306,6 +310,8 @@ def validate(args, trainer, task, epoch_itr, subsets): valid_losses = [] for subset in subsets: + logger.info("begin validation on \"{}\" subset".format(subset)) + # Initialize data iterator itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False) if getattr(args, "tpu", False): diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index f2c918e32..921bc725b 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -30,7 +30,7 @@ def get_asr_dataset_from_json( data_path, split, tgt_dict, combine, upsample_primary, - num_buckets=0, + num_buckets=0, shuffle=True, seed=1, specaugment_config=None, ): """ @@ -114,6 +114,7 @@ def get_asr_dataset_from_json( left_pad_source=False, left_pad_target=False, num_buckets=num_buckets, + shuffle=shuffle, ) @@ -240,6 +241,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): combine=combine, upsample_primary=self.args.upsample_primary, num_buckets=self.args.num_batch_buckets, + shuffle=(split != getattr(self.args, "gen_subset", "")), seed=self.args.seed, specaugment_config=self.specaugment_config, ) diff --git a/espresso/tasks/speech_recognition_hybrid.py b/espresso/tasks/speech_recognition_hybrid.py index 68011d258..ad0f1feb9 100644 --- a/espresso/tasks/speech_recognition_hybrid.py +++ b/espresso/tasks/speech_recognition_hybrid.py @@ -39,6 +39,7 @@ def get_asr_dataset_from_json( data_path, split, dictionary, combine, upsample_primary, num_buckets=0, + shuffle=True, lf_mmi=True, seed=1, specaugment_config=None, chunk_width=None, chunk_left_context=None, chunk_right_context=None, label_delay=0, @@ -147,6 +148,7 @@ def get_asr_dataset_from_json( tgt_dataset, tgt_dataset_sizes, text=text_dataset, num_buckets=num_buckets, + shuffle=shuffle, ) else: return AsrXentDataset( @@ -154,6 +156,7 @@ def get_asr_dataset_from_json( tgt_dataset, tgt_dataset_sizes, text=text_dataset, num_buckets=num_buckets, + shuffle=shuffle, seed=seed, chunk_width=chunk_width, chunk_left_context=chunk_left_context, chunk_right_context=chunk_right_context, label_delay=label_delay, random_chunking=(split == "train" and chunk_width is not None), @@ -321,6 +324,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): combine=combine, upsample_primary=self.args.upsample_primary, num_buckets=self.args.num_batch_buckets, + shuffle=(split != getattr(self.args, "gen_subset", "")), lf_mmi=(self.args.criterion == "lattice_free_mmi"), seed=self.args.seed, specaugment_config=self.specaugment_config, chunk_width=None if self.training_stage and split in self.args.valid_subset.split(",") else self.chunk_width, diff --git a/examples/asr_librispeech/run.sh b/examples/asr_librispeech/run.sh index c5a9c8228..d1d866767 100755 --- a/examples/asr_librispeech/run.sh +++ b/examples/asr_librispeech/run.sh @@ -254,7 +254,7 @@ if [ ${stage} -le 8 ]; then --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates $((6000/ngpus/update_freq)) \ --keep-interval-updates 3 --keep-last-epochs 5 --validate-interval 1 --best-checkpoint-metric wer \ --criterion label_smoothed_cross_entropy_v2 --label-smoothing 0.1 --smoothing-type uniform \ - --dict $dict --bpe sentencepiece --sentencepiece-vocab ${sentencepiece_model}.model \ + --dict $dict --bpe sentencepiece --sentencepiece-model ${sentencepiece_model}.model \ 
--max-source-positions 9999 --max-target-positions 999 \ $opts --specaugment-config "$specaug_config" 2>&1 | tee $log_file fi @@ -277,7 +277,7 @@ if [ ${stage} -le 9 ]; then decode_dir=$dir/decode_$dataset${decode_affix:+_${decode_affix}} CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py data \ --task speech_recognition_espresso --user-dir espresso --max-tokens 15000 --max-sentences 24 \ - --num-shards 1 --shard-id 0 --dict $dict --bpe sentencepiece --sentencepiece-vocab ${sentencepiece_model}.model \ + --num-shards 1 --shard-id 0 --dict $dict --bpe sentencepiece --sentencepiece-model ${sentencepiece_model}.model \ --gen-subset $dataset --max-source-positions 9999 --max-target-positions 999 \ --path $path --beam 60 --max-len-a 0.08 --max-len-b 0 --lenpen 1.0 \ --results-path $decode_dir $opts diff --git a/examples/asr_swbd/run.sh b/examples/asr_swbd/run.sh index cc748b57a..221f520fe 100755 --- a/examples/asr_swbd/run.sh +++ b/examples/asr_swbd/run.sh @@ -293,7 +293,7 @@ if [ $stage -le 7 ]; then --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates $((3000/ngpus/update_freq)) \ --keep-interval-updates 3 --keep-last-epochs 5 --validate-interval 1 --best-checkpoint-metric wer \ --criterion label_smoothed_cross_entropy_v2 --label-smoothing 0.1 --smoothing-type uniform \ - --dict $dict --bpe sentencepiece --sentencepiece-vocab ${sentencepiece_model}.model --non-lang-syms $nlsyms \ + --dict $dict --bpe sentencepiece --sentencepiece-model ${sentencepiece_model}.model --non-lang-syms $nlsyms \ --max-source-positions 9999 --max-target-positions 999 \ $opts --specaugment-config "$specaug_config" 2>&1 | tee $log_file fi @@ -314,7 +314,7 @@ if [ $stage -le 8 ]; then decode_dir=$dir/decode_${dataset}${decode_affix:+_${decode_affix}} CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py data \ --task speech_recognition_espresso --user-dir espresso --max-tokens 24000 --max-sentences 48 \ - --num-shards 1 --shard-id 0 --dict $dict --bpe sentencepiece --sentencepiece-vocab ${sentencepiece_model}.model \ + --num-shards 1 --shard-id 0 --dict $dict --bpe sentencepiece --sentencepiece-model ${sentencepiece_model}.model \ --non-lang-syms $nlsyms --gen-subset $dataset --max-source-positions 9999 --max-target-positions 999 \ --path $path --beam 35 --max-len-a 0.1 --max-len-b 0 --lenpen 1.0 \ --results-path $decode_dir $opts From 7c21d3de63f0d134132b8968245680e685f75401 Mon Sep 17 00:00:00 2001 From: freewym Date: Thu, 16 Jul 2020 21:56:03 -0400 Subject: [PATCH 090/119] code adaptation/changes according to the commits on Jul 16, 2020 --- espresso/data/asr_chain_dataset.py | 20 +++-- espresso/data/asr_dataset.py | 52 +++++++++++-- espresso/data/asr_xent_dataset.py | 28 ++++--- espresso/speech_train.py | 86 +++------------------ espresso/tasks/language_modeling_for_asr.py | 29 ++----- espresso/tools/utils.py | 3 +- examples/asr_librispeech/run.sh | 4 +- examples/asr_swbd/run.sh | 4 +- examples/asr_wsj/run.sh | 6 +- examples/asr_wsj/run_chain_e2e.sh | 10 ++- examples/asr_wsj/run_xent.sh | 10 ++- 11 files changed, 115 insertions(+), 137 deletions(-) diff --git a/espresso/data/asr_chain_dataset.py b/espresso/data/asr_chain_dataset.py index 1258982ca..c819e0453 100644 --- a/espresso/data/asr_chain_dataset.py +++ b/espresso/data/asr_chain_dataset.py @@ -20,7 +20,7 @@ logger = logging.getLogger(__name__) -def collate(samples, src_bucketed=False): +def collate(samples, pad_to_length=None, src_bucketed=False): try: from 
pychain import ChainGraphBatch except ImportError: @@ -29,9 +29,12 @@ def collate(samples, src_bucketed=False): if len(samples) == 0: return {} - def merge(key): + def merge(key, pad_to_length=None): if key == "source": - return speech_utils.collate_frames([s[key] for s in samples], 0.0) + return speech_utils.collate_frames( + [s[key] for s in samples], 0.0, + pad_to_length=pad_to_length, + ) elif key == "target": max_num_transitions = max(s["target"].num_transitions for s in samples) max_num_states = max(s["target"].num_states for s in samples) @@ -44,9 +47,9 @@ def merge(key): raise ValueError("Invalid key.") id = torch.LongTensor([s["id"] for s in samples]) - src_frames = merge("source") + src_frames = merge("source", pad_to_length=pad_to_length["source"] if pad_to_length is not None else None) # sort by descending source length - if src_bucketed: + if pad_to_length is not None or src_bucketed: src_lengths = torch.IntTensor([ s["source"].ne(0.0).any(dim=1).int().sum() for s in samples ]) @@ -280,11 +283,14 @@ def __getitem__(self, index): def __len__(self): return len(self.src) - def collater(self, samples): + def collater(self, samples, pad_to_length=None): """Merge a list of samples to form a mini-batch. Args: samples (List[dict]): samples to collate + pad_to_length (dict, optional): a dictionary of + {'source': source_pad_to_length} + to indicate the max length to pad to in source and target respectively. Returns: dict: a mini-batch with the following keys: @@ -304,7 +310,7 @@ def collater(self, samples): numerator graphs - `text` (List[str]): list of original text """ - return collate(samples, src_bucketed=(self.buckets is not None)) + return collate(samples, pad_to_length=pad_to_length, src_bucketed=(self.buckets is not None)) def num_tokens(self, index): """Return the number of frames in a sample. 
This value is used to diff --git a/espresso/data/asr_dataset.py b/espresso/data/asr_dataset.py index dba413de0..799ee82cc 100644 --- a/espresso/data/asr_dataset.py +++ b/espresso/data/asr_dataset.py @@ -22,28 +22,34 @@ def collate( left_pad_source=True, left_pad_target=False, input_feeding=True, + pad_to_length=None, src_bucketed=False, ): if len(samples) == 0: return {} - def merge(key, left_pad, move_eos_to_beginning=False): + def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None): if key == 'source': return speech_utils.collate_frames( [s[key] for s in samples], 0.0, left_pad, + pad_to_length=pad_to_length, ) elif key == 'target' or key == 'prev_output_tokens': return data_utils.collate_tokens( [s[key] for s in samples], pad_idx, eos_idx, left_pad, move_eos_to_beginning, + pad_to_length=pad_to_length, ) else: raise ValueError('Invalid key.') id = torch.LongTensor([s['id'] for s in samples]) - src_frames = merge('source', left_pad=left_pad_source) + src_frames = merge( + 'source', left_pad=left_pad_source, + pad_to_length=pad_to_length['source'] if pad_to_length is not None else None, + ) # sort by descending source length - if src_bucketed: + if pad_to_length is not None or src_bucketed: src_lengths = torch.IntTensor([ s['source'].ne(0.0).any(dim=1).int().sum() for s in samples ]) @@ -57,7 +63,10 @@ def merge(key, left_pad, move_eos_to_beginning=False): prev_output_tokens = None target = None if samples[0].get('target', None) is not None: - target = merge('target', left_pad=left_pad_target) + target = merge( + 'target', left_pad=left_pad_target, + pad_to_length=pad_to_length['target'] if pad_to_length is not None else None, + ) target = target.index_select(0, sort_order) ntokens = sum(s['target'].ne(pad_idx).int().sum().item() for s in samples) @@ -70,6 +79,7 @@ def merge(key, left_pad, move_eos_to_beginning=False): 'target', left_pad=left_pad_target, move_eos_to_beginning=True, + pad_to_length=pad_to_length['target'] if pad_to_length is not None else None, ) prev_output_tokens = prev_output_tokens.index_select(0, sort_order) else: @@ -116,6 +126,12 @@ class AsrDataset(FairseqDataset): to be passed into the model for teacher forcing (default: True). num_buckets (int, optional): if set to a value greater than 0, then batches will be bucketed into the given number of batch shapes. + src_lang_id (int, optional): source language ID, if set, the collated batch + will contain a field 'src_lang_id' in 'net_input' which indicates the + source language of the samples. + tgt_lang_id (int, optional): target language ID, if set, the collated batch + will contain a field 'tgt_lang_id' which indicates the target language + of the samples. """ def __init__( @@ -124,6 +140,8 @@ def __init__( left_pad_source=False, left_pad_target=False, shuffle=True, input_feeding=True, num_buckets=0, + src_lang_id=None, + tgt_lang_id=None, ): self.src = src self.tgt = tgt @@ -134,6 +152,8 @@ def __init__( self.left_pad_target = left_pad_target self.shuffle = shuffle self.input_feeding = input_feeding + self.src_lang_id = src_lang_id + self.tgt_lang_id = tgt_lang_id if self.tgt is not None: self._match_src_tgt() @@ -212,11 +232,14 @@ def __getitem__(self, index): def __len__(self): return len(self.src) - def collater(self, samples): + def collater(self, samples, pad_to_length=None): """Merge a list of samples to form a mini-batch. 
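The new pad_to_length argument threads a minimum padded length through collation, so batches can be padded out to a fixed source/target size (e.g. for bucketed batch shapes). A small usage sketch; the dataset instance, the samples list, and the concrete lengths are assumed for illustration, and each value is a lower bound (a longer sample still determines the padded size).

# `dataset` is an AsrDataset instance, `samples` a list of its items
batch = dataset.collater(
    samples,
    pad_to_length={"source": 1600, "target": 120},  # frames / tokens per example
)
# assuming no sample exceeds these lengths:
assert batch["net_input"]["src_tokens"].size(1) == 1600  # (bsz, frames, feat_dim)
assert batch["target"].size(1) == 120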
Args: samples (List[dict]): samples to collate + pad_to_length (dict, optional): a dictionary of + {'source': source_pad_to_length, 'target': target_pad_to_length} + to indicate the max length to pad to in source and target respectively. Returns: dict: a mini-batch with the following keys: @@ -238,21 +261,38 @@ def collater(self, samples): This key will not be present if *input_feeding* is ``False``. Padding will appear on the left if *left_pad_target* is ``True``. + - `src_lang_id` (LongTensor): a long Tensor which contains source + language IDs of each sample in the batch - `target` (LongTensor): a padded 2D Tensor of tokens in the target sentence of shape `(bsz, tgt_len)`. Padding will appear on the left if *left_pad_target* is ``True``. - `target_raw_text` (List[str]): list of original text + - `tgt_lang_id` (LongTensor): a long Tensor which contains target language + IDs of each sample in the batch """ - return collate( + res = collate( samples, pad_idx=self.dictionary.pad(), eos_idx=self.dictionary.eos(), left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target, input_feeding=self.input_feeding, + pad_to_length=pad_to_length, src_bucketed=(self.buckets is not None), ) + if self.src_lang_id is not None or self.tgt_lang_id is not None: + src_tokens = res['net_input']['src_tokens'] + bsz = src_tokens.size(0) + if self.src_lang_id is not None: + res['net_input']['src_lang_id'] = torch.LongTensor( + [[self.src_lang_id]] + ).expand(bsz, 1).to(src_tokens) + if self.tgt_lang_id is not None: + res['tgt_lang_id'] = torch.LongTensor( + [[self.tgt_lang_id]] + ).expand(bsz, 1).to(src_tokens) + return res def num_tokens(self, index): """Return the number of frames in a sample. This value is used to diff --git a/espresso/data/asr_xent_dataset.py b/espresso/data/asr_xent_dataset.py index 2135983c3..b96c60f72 100644 --- a/espresso/data/asr_xent_dataset.py +++ b/espresso/data/asr_xent_dataset.py @@ -34,20 +34,25 @@ def collate( label_delay, seed, epoch, + pad_to_length=None, src_bucketed=False, random_chunking=True, ): if len(samples) == 0: return {} - def merge(key): + def merge(key, pad_to_length=None): if key == "source": - return speech_utils.collate_frames([s[key] for s in samples], 0.0) + return speech_utils.collate_frames( + [s[key] for s in samples], 0.0, + pad_to_length=pad_to_length, + ) elif key == "target": return data_utils.collate_tokens( [s[key] for s in samples], pad_idx=pad_idx, eos_idx=None, left_pad=False, move_eos_to_beginning=False, + pad_to_length=pad_to_length, ) else: raise ValueError("Invalid key.") @@ -102,7 +107,7 @@ def chunking(src_item, tgt_item, tgt_start): else: s["source"] = src_item[: label_delay] - if src_bucketed: + if pad_to_length is not None or src_bucketed: src_lengths = torch.IntTensor([ s["source"].ne(0.0).any(dim=1).int().sum() for s in samples ]) @@ -110,11 +115,11 @@ def chunking(src_item, tgt_item, tgt_start): src_lengths = torch.IntTensor([s["source"].size(0) for s in samples]) id = torch.LongTensor([s["id"] for s in samples]) utt_id = [s["utt_id"] for s in samples] - src_frames = merge("source") + src_frames = merge("source", pad_to_length=pad_to_length["source"] if pad_to_length is not None else None) target = None if samples[0].get("target", None) is not None: - target = merge("target") + target = merge("target", pad_to_length=pad_to_length["target"] if pad_to_length is not None else None) ntokens = sum(s["target"].ne(pad_idx).int().sum().item() for s in samples) else: ntokens = src_lengths.sum().item() @@ -148,7 +153,7 @@ def 
chunking(src_item, tgt_item, tgt_start): } return batch else: # sequential chunking, usually for chunk-wise test data - if src_bucketed: + if pad_to_length is not None or src_bucketed: src_lengths = torch.IntTensor([ s["source"].ne(0.0).any(dim=1).int().sum() for s in samples ]) @@ -175,12 +180,12 @@ def chunking(src_item, tgt_item, tgt_start): ) s["target"] = ori_target[i].new_full((chunk_width,), pad_idx) \ if ori_target[i] is not None else None - src_frames = merge("source") + src_frames = merge("source", pad_to_length=pad_to_length["source"] if pad_to_length is not None else None) src_chunk_lengths = torch.IntTensor([s["source"].size(0) for s in samples]) target = None if samples[0].get("target", None) is not None: - target = merge("target") + target = merge("target", pad_to_length=pad_to_length["target"] if pad_to_length is not None else None) ntokens = sum(s["target"].ne(pad_idx).int().sum().item() for s in samples) else: ntokens = src_lengths.sum().item() @@ -494,11 +499,15 @@ def __getitem__(self, index): def __len__(self): return len(self.src) - def collater(self, samples): + def collater(self, samples, pad_to_length=None): """Merge a list of samples to form a mini-batch. Args: samples (List[dict]): samples to collate + pad_to_length (dict, optional): a dictionary of + {'source': source_pad_to_length, 'target': target_pad_to_length} + to indicate the max length to pad to in source and target respectively. + Returns: dict: a mini-batch with the following keys: @@ -528,6 +537,7 @@ def collater(self, samples): label_delay=self.label_delay, seed=self.seed, epoch=self.epoch, + pad_to_length=pad_to_length, src_bucketed=(self.buckets is not None), random_chunking=self.random_chunking, ) diff --git a/espresso/speech_train.py b/espresso/speech_train.py index 87d7ec178..576b26099 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -8,13 +8,9 @@ Train a new model on one or across multiple GPUs. 
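For reference, the padding rule that the new pad_to_length argument introduces in the collaters above reduces to one line: the padded length of a batch is the longer of the batch's own maximum length and the requested pad_to_length, so the argument can only grow the padding, never truncate. Below is a minimal standalone sketch of that rule in plain PyTorch; the helper name pad_frames and the toy tensors are assumptions for illustration, not the actual speech_utils.collate_frames (whose corresponding change appears later in this patch).

import torch

def pad_frames(values, pad_value=0.0, pad_to_length=None):
    # values: list of 2D tensors, each of shape (num_frames_i, feat_dim)
    length = max(v.size(0) for v in values)
    # pad_to_length only ever increases the padded length
    length = length if pad_to_length is None else max(length, pad_to_length)
    feat_dim = values[0].size(1)
    res = values[0].new_full((len(values), length, feat_dim), pad_value)
    for i, v in enumerate(values):
        res[i, : v.size(0), :] = v
    return res

feats = [torch.randn(7, 4), torch.randn(5, 4)]
assert pad_frames(feats).shape == (2, 7, 4)                     # no pad_to_length: batch max
assert pad_frames(feats, pad_to_length=10).shape == (2, 10, 4)  # padded out to 10 frames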
""" -import argparse import logging import math -import os -import random import sys -from typing import Callable, Optional import numpy as np import torch @@ -41,13 +37,7 @@ logger = logging.getLogger("espresso.speech_train") -def main( - args, - init_distributed=False, - after_distributed_init_fn: Optional[ - Callable[[argparse.Namespace], argparse.Namespace] - ] = None, -): +def main(args): utils.import_user_module(args) assert ( @@ -55,15 +45,8 @@ def main( ), "Must specify batch size either with --max-tokens or --max-sentences" metrics.reset() - # Initialize CUDA and distributed training - if torch.cuda.is_available() and not args.cpu and not getattr(args, "tpu", False): - torch.cuda.set_device(args.device_id) np.random.seed(args.seed) utils.set_torch_seed(args.seed) - if init_distributed: - args.distributed_rank = distributed_utils.distributed_init(args) - if after_distributed_init_fn: - args = after_distributed_init_fn(args) if distributed_utils.is_master(args): checkpoint_utils.verify_checkpoint_directory(args.save_dir) @@ -131,6 +114,7 @@ def main( lr = trainer.get_lr() train_meter = meters.StopwatchMeter() train_meter.start() + while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch: # train for one epoch valid_losses, should_stop = train(args, trainer, task, epoch_itr) @@ -143,7 +127,7 @@ def main( epoch_itr = trainer.get_train_iterator( epoch_itr.next_epoch_idx, # sharded data: get train iterator for next epoch - load_dataset=(os.pathsep in getattr(args, "data", "")), + load_dataset=task.has_sharded_data("train"), ) train_meter.stop() logger.info("done training in {:.1f} seconds".format(train_meter.sum)) @@ -251,6 +235,12 @@ def train(args, trainer, task, epoch_itr): valid_losses, should_stop = validate_and_save( args, trainer, task, epoch_itr, valid_subsets, end_of_epoch ) + + if args.stop_time_hours > 0: + elapsed_hours = trainer.cumulative_training_time() / (60 * 60) + if elapsed_hours > args.stop_time_hours: + should_stop = True + if should_stop: break @@ -353,22 +343,6 @@ def get_valid_stats(args, trainer, stats): return stats -def distributed_main( - i, - args, - start_rank=0, - after_distributed_init_fn: Optional[ - Callable[[argparse.Namespace], argparse.Namespace] - ] = None, -): - args.device_id = i - if args.distributed_rank is None: # torch.multiprocessing.spawn - args.distributed_rank = start_rank + i - main( - args, init_distributed=True, after_distributed_init_fn=after_distributed_init_fn - ) - - def print_options_meaning_changes(args): """Options that have different meanings than those in the translation task are explained here. 
@@ -383,47 +357,9 @@ def cli_main(modify_parser=None): if args.profile: with torch.cuda.profiler.profile(): with torch.autograd.profiler.emit_nvtx(): - cli_main_helper(args) - else: - cli_main_helper(args) - - -def cli_main_helper(args): - if args.distributed_init_method is None: - distributed_utils.infer_init_method(args) - - if args.distributed_init_method is not None: - # distributed training - if torch.cuda.device_count() > 1 and not args.distributed_no_spawn: - start_rank = args.distributed_rank - args.distributed_rank = None # assign automatically - torch.multiprocessing.spawn( - fn=distributed_main, - args=(args, start_rank), - nprocs=torch.cuda.device_count(), - ) - else: - distributed_main(args.device_id, args) - elif args.distributed_world_size > 1: - if not getattr(args, "tpu", False): - # fallback for single node with multiple GPUs - assert args.distributed_world_size <= torch.cuda.device_count() - port = random.randint(10000, 20000) - args.distributed_init_method = "tcp://localhost:{port}".format(port=port) - args.distributed_rank = None # set based on device id - torch.multiprocessing.spawn( - fn=distributed_main, args=(args,), nprocs=args.distributed_world_size - ) - else: - import torch_xla.distributed.xla_multiprocessing as xmp - - torch.multiprocessing.set_sharing_strategy("file_system") - xmp.spawn( - fn=distributed_main, args=(args,), nprocs=8 # use all 8 TPU cores - ) + distributed_utils.call_main(args, main) else: - # single GPU training - main(args) + distributed_utils.call_main(args, main) if __name__ == "__main__": diff --git a/espresso/tasks/language_modeling_for_asr.py b/espresso/tasks/language_modeling_for_asr.py index 423430ad9..8f3c299e0 100644 --- a/espresso/tasks/language_modeling_for_asr.py +++ b/espresso/tasks/language_modeling_for_asr.py @@ -96,19 +96,16 @@ def build_dictionary(cls, filenames, workers=1, threshold=-1, nwords=-1, padding return d @classmethod - def setup_task(cls, args, **kwargs): - """Setup the task (e.g., load dictionaries). 
- - Args: - args (argparse.Namespace): parsed command-line arguments - """ + def setup_dictionary(cls, args, **kwargs): dictionary = None output_dictionary = None if args.data: paths = utils.split_paths(args.data) assert len(paths) > 0 - dict_path = os.path.join(paths[0], "dict.txt") if args.dict is None \ + dict_path = ( + os.path.join(paths[0], "dict.txt") if args.dict is None else args.dict + ) dictionary = AsrDictionary.load(dict_path) logger.info("dictionary: {} types".format(len(dictionary))) output_dictionary = dictionary @@ -116,20 +113,4 @@ def setup_task(cls, args, **kwargs): output_dictionary = TruncatedDictionary( dictionary, args.output_dictionary_size ) - - # upgrade old checkpoints - if hasattr(args, "exclude_self_target"): - args.self_target = not args.exclude_self_target - - targets = [] - if getattr(args, "self_target", False): - targets.append("self") - if getattr(args, "future_target", False): - targets.append("future") - if getattr(args, "past_target", False): - targets.append("past") - if len(targets) == 0: - # standard language modeling - targets = ["future"] - - return cls(args, dictionary, output_dictionary, targets=targets) + return (dictionary, output_dictionary) diff --git a/espresso/tools/utils.py b/espresso/tools/utils.py index 3de35d7f8..c0550ec74 100644 --- a/espresso/tools/utils.py +++ b/espresso/tools/utils.py @@ -38,10 +38,11 @@ def tokenize(sent, space='', non_lang_syms=None): return ' '.join(tokens) -def collate_frames(values, pad_value=0.0, left_pad=False): +def collate_frames(values, pad_value=0.0, left_pad=False, pad_to_length=None): """Convert a list of 2d tensor into a padded 3d tensor.""" assert values[0].dim() == 2, "expected 2, got " + str(values[0].dim) length = max(v.size(0) for v in values) + length = length if pad_to_length is None else max(length, pad_to_length) dim = values[0].size(1) res = values[0].new(len(values), length, dim).fill_(pad_value) diff --git a/examples/asr_librispeech/run.sh b/examples/asr_librispeech/run.sh index d1d866767..5be8af65f 100755 --- a/examples/asr_librispeech/run.sh +++ b/examples/asr_librispeech/run.sh @@ -178,7 +178,7 @@ if [ ${stage} -le 5 ]; then --log-interval $((16000/ngpus)) --log-format simple \ --num-workers 0 --max-tokens 32000 --max-sentences 1024 --curriculum 1 \ --valid-subset $valid_subset --max-sentences-valid 1536 \ - --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ + --distributed-world-size $ngpus \ --max-epoch 30 --optimizer adam --lr 0.001 --clip-norm 1.0 \ --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \ --save-dir $lmdir --restore-file checkpoint_last.pt --save-interval-updates $((16000/ngpus)) \ @@ -249,7 +249,7 @@ if [ ${stage} -le 8 ]; then --log-interval $((8000/ngpus/update_freq)) --log-format simple --print-training-sample-interval $((4000/ngpus/update_freq)) \ --num-workers 0 --data-buffer-size 0 --max-tokens 26000 --max-sentences 24 --curriculum 1 --empty-cache-freq 50 \ --valid-subset $valid_subset --max-sentences-valid 48 --ddp-backend no_c10d --update-freq $update_freq \ - --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ + --distributed-world-size $ngpus \ --optimizer adam --lr 0.001 --weight-decay 0.0 --clip-norm 2.0 \ --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates $((6000/ngpus/update_freq)) \ --keep-interval-updates 3 --keep-last-epochs 5 --validate-interval 1 --best-checkpoint-metric wer \ diff --git 
a/examples/asr_swbd/run.sh b/examples/asr_swbd/run.sh index 221f520fe..d59a218f5 100755 --- a/examples/asr_swbd/run.sh +++ b/examples/asr_swbd/run.sh @@ -216,7 +216,7 @@ if [ $stage -le 4 ]; then --log-interval $((1000/ngpus)) --log-format simple \ --num-workers 0 --max-tokens 25600 --max-sentences 1024 \ --valid-subset $valid_subset --max-sentences-valid 1536 \ - --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ + --distributed-world-size $ngpus \ --max-epoch 25 --optimizer adam --lr 0.001 --clip-norm 1.0 \ --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \ --save-dir $lmdir --restore-file checkpoint_last.pt --save-interval-updates $((1000/ngpus)) \ @@ -288,7 +288,7 @@ if [ $stage -le 7 ]; then --log-interval $((3000/ngpus/update_freq)) --log-format simple --print-training-sample-interval $((4000/ngpus/update_freq)) \ --num-workers 0 --data-buffer-size 0 --max-tokens 26000 --max-sentences 48 --curriculum 2 --empty-cache-freq 50 \ --valid-subset $valid_subset --max-sentences-valid 64 --ddp-backend no_c10d --update-freq $update_freq \ - --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ + --distributed-world-size $ngpus \ --optimizer adam --lr 0.001 --weight-decay 0.0 --clip-norm 2.0 \ --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates $((3000/ngpus/update_freq)) \ --keep-interval-updates 3 --keep-last-epochs 5 --validate-interval 1 --best-checkpoint-metric wer \ diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index 4171d078c..6c372a198 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -194,7 +194,7 @@ if [ ${stage} -le 4 ] && ! $use_wordlm; then --log-interval $((4000/ngpus)) --log-format simple \ --num-workers 0 --max-tokens 25600 --max-sentences 128 \ --valid-subset $valid_subset --max-sentences-valid 256 \ - --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ + --distributed-world-size $ngpus \ --max-epoch 25 --optimizer adam --lr 0.001 --weight-decay 5e-06 \ --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \ --save-dir $lmdir --restore-file checkpoint_last.pt --save-interval-updates $((4000/ngpus)) \ @@ -224,7 +224,7 @@ if [ ${stage} -le 6 ] && $use_wordlm; then --log-interval $((4000/ngpus)) --log-format simple \ --num-workers 0 --max-tokens 6400 --max-sentences 256 \ --valid-subset $valid_subset --max-sentences-valid 512 \ - --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ + --distributed-world-size $ngpus \ --max-epoch 25 --optimizer adam --lr 0.001 --weight-decay 0.0 \ --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \ --save-dir $wordlmdir --restore-file checkpoint_last.pt --save-interval-updates $((4000/ngpus)) \ @@ -287,7 +287,7 @@ if [ ${stage} -le 9 ]; then --log-interval $((800/ngpus/update_freq)) --log-format simple --print-training-sample-interval $((2000/ngpus/update_freq)) \ --num-workers 0 --data-buffer-size 0 --max-tokens 24000 --max-sentences 32 --curriculum 2 --empty-cache-freq 50 \ --valid-subset $valid_subset --max-sentences-valid 64 --ddp-backend no_c10d --update-freq $update_freq \ - --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ + --distributed-world-size $ngpus \ --optimizer adam --lr 0.001 --weight-decay 0.0 \ --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates 
$((800/ngpus/update_freq)) \ --keep-interval-updates 5 --keep-last-epochs 5 --validate-interval 1 --best-checkpoint-metric wer \ diff --git a/examples/asr_wsj/run_chain_e2e.sh b/examples/asr_wsj/run_chain_e2e.sh index 7331815cc..b1aa898ca 100755 --- a/examples/asr_wsj/run_chain_e2e.sh +++ b/examples/asr_wsj/run_chain_e2e.sh @@ -184,13 +184,15 @@ if [ ${stage} -le 6 ]; then mkdir -p $dir/log log_file=$dir/log/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" + update_freq=1 CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data/chain_e2e --task speech_recognition_hybrid --seed 1 --user-dir espresso \ - --log-interval $((200/ngpus)) --log-format simple --num-workers 0 --data-buffer-size 0 --max-tokens 120000 --max-sentences 128 \ - --curriculum 1 --valid-subset $valid_subset --max-sentences-valid 128 --ddp-backend no_c10d \ - --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ + --log-interval $((200/ngpus/update_freq)) --log-format simple \ + --num-workers 0 --data-buffer-size 0 --max-tokens 120000 --max-sentences 128 --curriculum 1 --empty-cache-freq 50 \ + --valid-subset $valid_subset --max-sentences-valid 128 --ddp-backend no_c10d --update-freq $update_freq \ + --distributed-world-size $ngpus \ --max-epoch 26 --optimizer adam --lr 0.001 --weight-decay 0.0 --start-reduce-lr-epoch 11 \ --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 \ - --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates $((400/ngpus)) \ + --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates $((400/ngpus/update_freq)) \ --keep-interval-updates 5 --keep-last-epochs 5 --validate-interval 1 \ --arch speech_tdnn_wsj --criterion lattice_free_mmi --num-targets $num_targets \ --dropout 0.2 --kernel-sizes "[3]*6" --strides "[1]*5+[3]" --dilations "[1,1,1,3,3,3]" --num-layers 6 --residual True \ diff --git a/examples/asr_wsj/run_xent.sh b/examples/asr_wsj/run_xent.sh index b8c8f10a7..63fbcb4f6 100755 --- a/examples/asr_wsj/run_xent.sh +++ b/examples/asr_wsj/run_xent.sh @@ -164,13 +164,15 @@ if [ ${stage} -le 5 ]; then mkdir -p $dir/log log_file=$dir/log/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" + update_freq=1 CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data/xent --task speech_recognition_hybrid --seed 1 --user-dir espresso \ - --log-interval $((100/ngpus)) --log-format simple --num-workers 0 --data-buffer-size 0 --max-tokens 160000 --max-sentences 256 \ - --valid-subset $valid_subset --max-sentences-valid 256 --ddp-backend no_c10d \ - --distributed-world-size $ngpus --distributed-port $(if [ $ngpus -gt 1 ]; then echo 100; else echo -1; fi) \ + --log-interval $((100/ngpus/update_freq)) --log-format simple \ + --num-workers 0 --data-buffer-size 0 --max-tokens 160000 --max-sentences 256 --empty-cache-freq 50 \ + --valid-subset $valid_subset --max-sentences-valid 256 --ddp-backend no_c10d --update-freq $update_freq \ + --distributed-world-size $ngpus \ --max-epoch 40 --optimizer adam --lr 0.001 --weight-decay 0.0 \ --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 \ - --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates $((200/ngpus)) \ + --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates $((200/ngpus/update_freq)) \ --keep-interval-updates 5 --keep-last-epochs 5 --validate-interval 1 \ --arch speech_tdnn_wsj --criterion subsampled_cross_entropy_with_accuracy --num-targets $num_targets \ --initial-state-prior-file $state_prior_file 
--state-prior-update-interval 10 --state-prior-update-smoothing 0.01 \ From 4b253b1340f0548924959c694f2a4f99cc551420 Mon Sep 17 00:00:00 2001 From: freewym Date: Mon, 20 Jul 2020 23:16:27 -0400 Subject: [PATCH 091/119] code adaptation/changes according to the commits on Jul 20-25, 2020 --- espresso/speech_recognize.py | 9 ++++++++- espresso/speech_train.py | 15 ++++++++------- espresso/tasks/speech_recognition.py | 7 +++++-- espresso/tools/simple_greedy_decoder.py | 9 +++++++-- 4 files changed, 28 insertions(+), 12 deletions(-) diff --git a/espresso/speech_recognize.py b/espresso/speech_recognize.py index 6a2f52561..003c390f3 100755 --- a/espresso/speech_recognize.py +++ b/espresso/speech_recognize.py @@ -42,6 +42,13 @@ def main(args): return _main(args, sys.stdout) +def get_symbols_to_strip_from_output(generator): + if hasattr(generator, 'symbols_to_strip_from_output'): + return generator.symbols_to_strip_from_output + else: + return {generator.eos, generator.pad} + + def _main(args, output_file): logging.basicConfig( format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', @@ -200,7 +207,7 @@ def decode_fn(x): hypo_str = dictionary.string( hypo['tokens'].int().cpu(), bpe_symbol=None, - extra_symbols_to_ignore={dictionary.pad()}, + extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator), ) # not removing bpe at this point detok_hypo_str = decode_fn(hypo_str) if not args.quiet: diff --git a/espresso/speech_train.py b/espresso/speech_train.py index 576b26099..7f4fd0cf2 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -212,7 +212,9 @@ def train(args, trainer, task, epoch_itr): valid_subsets = args.valid_subset.split(",") should_stop = False for i, samples in enumerate(progress): - with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function("train_step-%d" % i): + with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function( + "train_step-%d" % i + ): log_output = trainer.train_step(samples) if log_output is None: # OOM, overflow, ... 
continue @@ -236,11 +238,6 @@ def train(args, trainer, task, epoch_itr): args, trainer, task, epoch_itr, valid_subsets, end_of_epoch ) - if args.stop_time_hours > 0: - elapsed_hours = trainer.cumulative_training_time() / (60 * 60) - if elapsed_hours > args.stop_time_hours: - should_stop = True - if should_stop: break @@ -276,6 +273,10 @@ def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoc should_stop = ( should_stop_early(args, valid_losses[0]) or trainer.get_num_updates() >= max_update + or ( + args.stop_time_hours > 0 + and trainer.cumulative_training_time() / (60 * 60) > args.stop_time_hours + ) ) # Save checkpoint @@ -300,7 +301,7 @@ def validate(args, trainer, task, epoch_itr, subsets): valid_losses = [] for subset in subsets: - logger.info("begin validation on \"{}\" subset".format(subset)) + logger.info('begin validation on "{}" subset'.format(subset)) # Initialize data iterator itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False) diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index 921bc725b..e9ea648e3 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -263,7 +263,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): unk_count += (tgt_dataset[i][0] == self.tgt_dict.unk()).int().sum().item() self.tgt_dict.count[self.tgt_dict.unk()] = unk_count - def build_generator(self, models, args): + def build_generator(self, models, args, seq_gen_cls=None): if getattr(args, "score_reference", False): args.score_reference = False logger.warning( @@ -322,7 +322,10 @@ def build_generator(self, models, args): else: search_strategy = search.BeamSearch(self.target_dictionary) - return SequenceGenerator( + if seq_gen_cls is None: + seq_gen_cls = SequenceGenerator + + return seq_gen_cls( models, self.target_dictionary, beam_size=getattr(args, "beam", 5), diff --git a/espresso/tools/simple_greedy_decoder.py b/espresso/tools/simple_greedy_decoder.py index c210cd569..007e9555f 100644 --- a/espresso/tools/simple_greedy_decoder.py +++ b/espresso/tools/simple_greedy_decoder.py @@ -15,7 +15,8 @@ class SimpleGreedyDecoder(nn.Module): def __init__( self, models, dictionary, max_len_a=0, max_len_b=200, - temperature=1.0, for_validation=True, + temperature=1.0, eos=None, symbols_to_strip_from_output=None, + for_validation=True, ): """Decode given speech audios with the simple greedy search. 
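The two changes around this point cooperate: speech_recognize.py (above) asks the generator which symbols to strip from printed hypotheses, falling back to {eos, pad} when the generator has no symbols_to_strip_from_output attribute, while SimpleGreedyDecoder (continued below) now exposes such an attribute that always contains at least its eos. A toy illustration of that dispatch; the Dummy classes are invented stand-ins for the example, not Espresso classes.

def strip_set(eos, extra=None):
    # the decoder's default: always strip eos, plus any extra symbols handed in
    return extra.union({eos}) if extra is not None else {eos}

def get_symbols_to_strip_from_output(generator):
    # same logic as the helper added in espresso/speech_recognize.py
    if hasattr(generator, "symbols_to_strip_from_output"):
        return generator.symbols_to_strip_from_output
    return {generator.eos, generator.pad}

class DummyDecoder:
    eos, pad = 2, 1
    symbols_to_strip_from_output = strip_set(eos=2, extra={3})

class DummyGenerator:
    eos, pad = 2, 1  # no explicit attribute, so fall back to {eos, pad}

assert get_symbols_to_strip_from_output(DummyDecoder()) == {2, 3}
assert get_symbols_to_strip_from_output(DummyGenerator()) == {1, 2}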
@@ -41,7 +42,11 @@ def __init__( self.model = EnsembleModel(models) self.pad = dictionary.pad() self.unk = dictionary.unk() - self.eos = dictionary.eos() + self.eos = dictionary.eos() if eos is None else eos + self.symbols_to_strip_from_output = ( + symbols_to_strip_from_output.union({self.eos}) + if symbols_to_strip_from_output is not None else {self.eos} + ) self.vocab_size = len(dictionary) self.max_len_a = max_len_a self.max_len_b = max_len_b From d2f264baafb3b7acdda8d1ea77bb9a3174a23aef Mon Sep 17 00:00:00 2001 From: freewym Date: Tue, 28 Jul 2020 18:42:39 -0400 Subject: [PATCH 092/119] fix reorder_encoder_out in SpeechChunkTransformerEncoder; code adaptation/changes according to the commits on Jul 28, 2020 --- .../speech_transformer_encoder_model.py | 58 +++++++++++++++++++ espresso/speech_train.py | 1 + 2 files changed, 59 insertions(+) diff --git a/espresso/models/speech_transformer_encoder_model.py b/espresso/models/speech_transformer_encoder_model.py index 124208fa2..3ffc2abfa 100644 --- a/espresso/models/speech_transformer_encoder_model.py +++ b/espresso/models/speech_transformer_encoder_model.py @@ -7,6 +7,7 @@ from typing import Optional import torch +from torch import Tensor import torch.nn.functional as F from fairseq import utils @@ -267,6 +268,63 @@ def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False): src_lengths=x_lengths, # B ) + @torch.jit.export + def reorder_encoder_out(self, encoder_out: EncoderOut, new_order): + """ + Reorder encoder output according to *new_order*. + + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + *encoder_out* rearranged according to *new_order* + """ + """ + Since encoder_padding_mask and encoder_embedding are both of type + Optional[Tensor] in EncoderOut, they need to be copied as local + variables for Torchscript Optional refinement + """ + encoder_padding_mask: Optional[Tensor] = encoder_out.encoder_padding_mask + encoder_embedding: Optional[Tensor] = encoder_out.encoder_embedding + + new_encoder_out = ( + encoder_out.encoder_out + if encoder_out.encoder_out is None + else encoder_out.encoder_out.index_select(1, new_order) + ) + new_encoder_padding_mask = ( + encoder_padding_mask + if encoder_padding_mask is None + else encoder_padding_mask.index_select(1, new_order) # note: transposed + ) + new_encoder_embedding = ( + encoder_embedding + if encoder_embedding is None + else encoder_embedding.index_select(0, new_order) + ) + src_tokens = encoder_out.src_tokens + if src_tokens is not None: + src_tokens = src_tokens.index_select(0, new_order) + + src_lengths = encoder_out.src_lengths + if src_lengths is not None: + src_lengths = src_lengths.index_select(0, new_order) + + encoder_states = encoder_out.encoder_states + if encoder_states is not None: + for idx, state in enumerate(encoder_states): + encoder_states[idx] = state.index_select(1, new_order) + + return EncoderOut( + encoder_out=new_encoder_out, # T x B x C + encoder_padding_mask=new_encoder_padding_mask, # B x T + encoder_embedding=new_encoder_embedding, # B x T x C + encoder_states=encoder_states, # List[T x B x C] + src_tokens=src_tokens, # B x T + src_lengths=src_lengths, # B x 1 + ) + @register_model_architecture("speech_transformer_encoder_model", "speech_transformer_encoder_model") def base_architecture(args): diff --git a/espresso/speech_train.py b/espresso/speech_train.py index 7f4fd0cf2..6e6f3b6a7 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -211,6 +211,7 @@ def 
train(args, trainer, task, epoch_itr): valid_subsets = args.valid_subset.split(",") should_stop = False + num_updates = trainer.get_num_updates() for i, samples in enumerate(progress): with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function( "train_step-%d" % i From a92409cfeafebd4dec9c5d64522f04db143ccfc3 Mon Sep 17 00:00:00 2001 From: freewym Date: Thu, 30 Jul 2020 19:49:54 -0400 Subject: [PATCH 093/119] reorder the elements of the returned tuple of TdnnModel.forward(); export KALDI_ROOT to adapt to the recent changes in kaldi_io; code adaptation/changes according to the commits on Aug 3-4, 2020 --- espresso/data/asr_dataset.py | 3 +-- espresso/dump_posteriors.py | 2 +- espresso/models/speech_tdnn.py | 4 ++-- espresso/models/speech_transformer.py | 2 -- .../optim/lr_scheduler/reduce_lr_on_plateau_v2.py | 6 ++++-- espresso/speech_recognize.py | 2 +- espresso/speech_train.py | 2 ++ espresso/tasks/speech_recognition.py | 11 ++++++++--- examples/asr_librispeech/path.sh | 2 +- examples/asr_swbd/path.sh | 2 +- examples/asr_wsj/path.sh | 2 +- fairseq/sequence_generator.py | 7 ++++--- 12 files changed, 26 insertions(+), 19 deletions(-) diff --git a/espresso/data/asr_dataset.py b/espresso/data/asr_dataset.py index 799ee82cc..59eb818bc 100644 --- a/espresso/data/asr_dataset.py +++ b/espresso/data/asr_dataset.py @@ -81,7 +81,6 @@ def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None): move_eos_to_beginning=True, pad_to_length=pad_to_length['target'] if pad_to_length is not None else None, ) - prev_output_tokens = prev_output_tokens.index_select(0, sort_order) else: ntokens = src_lengths.sum().item() @@ -102,7 +101,7 @@ def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None): 'target_raw_text': target_raw_text, } if prev_output_tokens is not None: - batch['net_input']['prev_output_tokens'] = prev_output_tokens + batch['net_input']['prev_output_tokens'] = prev_output_tokens.index_select(0, sort_order) return batch diff --git a/espresso/dump_posteriors.py b/espresso/dump_posteriors.py index 7020beace..8e90402c3 100755 --- a/espresso/dump_posteriors.py +++ b/espresso/dump_posteriors.py @@ -144,7 +144,7 @@ def _main(args, output_file): out_lengths = (~padding_mask).long().sum(dim=1).cpu() if padding_mask is not None else None num_processed_frames = sample["ntokens"] gen_timer.stop(num_processed_frames) - num_sentences += sample["nsentences"] + num_sentences += sample["nsentences"] if "nsentences" in sample else sample['id'].numel() if out_lengths is not None: for i in range(sample["nsentences"]): diff --git a/espresso/models/speech_tdnn.py b/espresso/models/speech_tdnn.py index 2ceee1423..4441270d4 100644 --- a/espresso/models/speech_tdnn.py +++ b/espresso/models/speech_tdnn.py @@ -226,7 +226,7 @@ def output_lengths(self, in_lengths): return out_lengths def forward(self, src_tokens, src_lengths: Tensor, **unused): - x, encoder_padding_mask, x_lengths = self.extract_features(src_tokens, src_lengths) + x, x_lengths, encoder_padding_mask = self.extract_features(src_tokens, src_lengths) if ( self.out_chunk_end is not None and (self.training or not self.training_stage) @@ -262,7 +262,7 @@ def extract_features(self, src_tokens, src_lengths, **unused): x = x.transpose(0, 1) # B x T x C -> T x B x C encoder_padding_mask = padding_mask.t() - return x, encoder_padding_mask, x_lengths + return x, x_lengths, encoder_padding_mask def output_layer(self, features, **kwargs): """Project features to the vocabulary size.""" diff --git 
a/espresso/models/speech_transformer.py b/espresso/models/speech_transformer.py index 804f52f8f..7778b7187 100644 --- a/espresso/models/speech_transformer.py +++ b/espresso/models/speech_transformer.py @@ -252,8 +252,6 @@ def __init__(self, args, conv_layers_before=None, input_size=83, transformer_con embed_dim = args.encoder_embed_dim self.max_source_positions = args.max_source_positions - self.embed_positions = None - self.conv_layers_before = conv_layers_before self.fc0 = Linear(input_size, embed_dim) if input_size != embed_dim else None diff --git a/espresso/optim/lr_scheduler/reduce_lr_on_plateau_v2.py b/espresso/optim/lr_scheduler/reduce_lr_on_plateau_v2.py index aa5a9e4ac..2bfa19093 100644 --- a/espresso/optim/lr_scheduler/reduce_lr_on_plateau_v2.py +++ b/espresso/optim/lr_scheduler/reduce_lr_on_plateau_v2.py @@ -22,8 +22,10 @@ def __init__(self, args, optimizer): super().__init__(args, optimizer) self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - self.optimizer.optimizer, patience=0, factor=args.lr_shrink, - threshold=args.lr_threshold, min_lr=args.final_lr_scale * args.lr[0]) + self.optimizer.optimizer, patience=args.lr_patience, factor=args.lr_shrink, + mode='max' if args.maximize_best_checkpoint_metric else 'min', + threshold=args.lr_threshold, min_lr=args.final_lr_scale * args.lr[0] + ) @staticmethod def add_args(parser): diff --git a/espresso/speech_recognize.py b/espresso/speech_recognize.py index 003c390f3..dc91673d4 100755 --- a/espresso/speech_recognize.py +++ b/espresso/speech_recognize.py @@ -229,7 +229,7 @@ def decode_fn(x): wps_meter.update(num_generated_tokens) progress.log({'wps': round(wps_meter.avg)}) - num_sentences += sample['nsentences'] + num_sentences += sample['nsentences'] if 'nsentences' in sample else sample['id'].numel() logger.info('NOTE: hypothesis and token scores are output in base 2') logger.info('Recognized {} utterances ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format( diff --git a/espresso/speech_train.py b/espresso/speech_train.py index 6e6f3b6a7..84cc4a490 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -258,10 +258,12 @@ def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoc args.save_interval_updates > 0 and num_updates > 0 and num_updates % args.save_interval_updates == 0 + and num_updates >= args.validate_after_updates ) or (end_of_epoch and epoch_itr.epoch % args.save_interval == 0) do_validate = ( (not end_of_epoch and do_save) # validate during mid-epoch saves or (end_of_epoch and epoch_itr.epoch % args.validate_interval == 0) + or (args.validate_interval_updates > 0 and num_updates % args.validate_interval_updates == 0) ) and not args.disable_validation # Validate diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index e9ea648e3..a5d031d0f 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -263,7 +263,10 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): unk_count += (tgt_dataset[i][0] == self.tgt_dict.unk()).int().sum().item() self.tgt_dict.count[self.tgt_dict.unk()] = unk_count - def build_generator(self, models, args, seq_gen_cls=None): + def build_generator( + self, models, args, + seq_gen_cls=None, extra_gen_cls_kwargs=None + ): if getattr(args, "score_reference", False): args.score_reference = False logger.warning( @@ -324,6 +327,9 @@ def build_generator(self, models, args, seq_gen_cls=None): if seq_gen_cls is None: seq_gen_cls = SequenceGenerator + 
extra_gen_cls_kwargs = extra_gen_cls_kwargs or {} + extra_gen_cls_kwargs["lm_weight"] = getattr(args, "lm_weight", 0.0) + extra_gen_cls_kwargs["eos_factor"] = getattr(args, "eos_factor", None) return seq_gen_cls( models, @@ -339,8 +345,7 @@ def build_generator(self, models, args, seq_gen_cls=None): match_source_len=getattr(args, "match_source_len", False), no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0), search_strategy=search_strategy, - lm_weight=getattr(args, "lm_weight", 0.0), - eos_factor=getattr(args, "eos_factor", None), + **extra_gen_cls_kwargs, ) def build_dataset_for_inference(self, src_tokens, src_lengths): diff --git a/examples/asr_librispeech/path.sh b/examples/asr_librispeech/path.sh index 863f5de3e..f990ad059 100644 --- a/examples/asr_librispeech/path.sh +++ b/examples/asr_librispeech/path.sh @@ -1,5 +1,5 @@ MAIN_ROOT=$PWD/../.. -KALDI_ROOT=$MAIN_ROOT/espresso/tools/kaldi +export KALDI_ROOT=$MAIN_ROOT/espresso/tools/kaldi # BEGIN from kaldi path.sh [ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh diff --git a/examples/asr_swbd/path.sh b/examples/asr_swbd/path.sh index 863f5de3e..f990ad059 100644 --- a/examples/asr_swbd/path.sh +++ b/examples/asr_swbd/path.sh @@ -1,5 +1,5 @@ MAIN_ROOT=$PWD/../.. -KALDI_ROOT=$MAIN_ROOT/espresso/tools/kaldi +export KALDI_ROOT=$MAIN_ROOT/espresso/tools/kaldi # BEGIN from kaldi path.sh [ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh diff --git a/examples/asr_wsj/path.sh b/examples/asr_wsj/path.sh index a190accd1..09e9b0182 100644 --- a/examples/asr_wsj/path.sh +++ b/examples/asr_wsj/path.sh @@ -1,5 +1,5 @@ MAIN_ROOT=$PWD/../.. -KALDI_ROOT=$MAIN_ROOT/espresso/tools/kaldi +export KALDI_ROOT=$MAIN_ROOT/espresso/tools/kaldi # BEGIN from kaldi path.sh [ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index fc2a15f2c..ebea3b20e 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -36,7 +36,7 @@ def __init__( symbols_to_strip_from_output=None, lm_model=None, lm_weight=1.0, - eos_factor=None, + **kwargs, ): """Generates translations of a given source sentence. 
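The reduce_lr_on_plateau_v2 change in this patch forwards --lr-patience and the direction of the best-checkpoint metric to torch's built-in scheduler: with mode="max" the learning rate is only shrunk when a metric that should grow (e.g. accuracy) stops improving. A standalone sketch of that wiring; the toy model, optimizer, and validation scores are assumptions for illustration.

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    patience=2,        # --lr-patience
    factor=0.5,        # --lr-shrink
    mode="max",        # args.maximize_best_checkpoint_metric is True
    threshold=1e-4,    # --lr-threshold
    min_lr=1e-5,       # args.final_lr_scale * args.lr[0]
)
for val_accuracy in [0.80, 0.81, 0.81, 0.81, 0.81]:  # accuracy plateaus
    scheduler.step(val_accuracy)
print(optimizer.param_groups[0]["lr"])  # 0.0005 after the plateau is detected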
@@ -64,6 +64,7 @@ def __init__( if isinstance(models, EnsembleModel): self.model = models else: + lm_weight = kwargs.get("lm_weight", 0.0) self.model = EnsembleModel(models) if lm_weight == 0.0 else LMFusionModel(models, lm_weight) self.tgt_dict = tgt_dict self.pad = tgt_dict.pad() @@ -88,9 +89,9 @@ def __init__( self.temperature = temperature self.match_source_len = match_source_len self.no_repeat_ngram_size = no_repeat_ngram_size - self.eos_factor = eos_factor + self.eos_factor = kwargs.get("eos_factor", None) assert temperature > 0, "--temperature must be greater than 0" - assert eos_factor is None or eos_factor >= 1.0, "--eos-factor must be >= 1.0 if set" + assert self.eos_factor is None or self.eos_factor >= 1.0, "--eos-factor must be >= 1.0 if set" self.search = ( search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy From c73c4de20bbbeb0f2609ccb86afa89b645a521c6 Mon Sep 17 00:00:00 2001 From: Yiming Wang Date: Sun, 9 Aug 2020 20:44:46 -0400 Subject: [PATCH 094/119] updates for new PyChain (#37) * add support for output l2 regularization and xent regularization; add a bichar WSJ recipe; add missing soft links to kaldi files * move ChainLossFunction here from PyChain --- espresso/criterions/lf_mmi_loss.py | 161 ++++++++++-- espresso/data/asr_chain_dataset.py | 2 +- examples/asr_wsj/local/data_prep_char.sh | 98 +++++++ .../asr_wsj/local/wsj_extend_char_dict.sh | 1 + examples/asr_wsj/local/wsj_format_data.sh | 1 + .../asr_wsj/local/wsj_format_local_lms.sh | 1 + .../asr_wsj/local/wsj_prepare_char_dict.sh | 1 + examples/asr_wsj/local/wsj_prepare_dict.sh | 1 + examples/asr_wsj/local/wsj_train_lms.sh | 1 + examples/asr_wsj/run_chain_e2e.sh | 7 +- examples/asr_wsj/run_chain_e2e_bichar.sh | 240 ++++++++++++++++++ 11 files changed, 490 insertions(+), 24 deletions(-) create mode 100755 examples/asr_wsj/local/data_prep_char.sh create mode 120000 examples/asr_wsj/local/wsj_extend_char_dict.sh create mode 120000 examples/asr_wsj/local/wsj_format_data.sh create mode 120000 examples/asr_wsj/local/wsj_format_local_lms.sh create mode 120000 examples/asr_wsj/local/wsj_prepare_char_dict.sh create mode 120000 examples/asr_wsj/local/wsj_prepare_dict.sh create mode 120000 examples/asr_wsj/local/wsj_train_lms.sh create mode 100755 examples/asr_wsj/run_chain_e2e_bichar.sh diff --git a/espresso/criterions/lf_mmi_loss.py b/espresso/criterions/lf_mmi_loss.py index 5fdd74d49..6c04ee01e 100644 --- a/espresso/criterions/lf_mmi_loss.py +++ b/espresso/criterions/lf_mmi_loss.py @@ -3,32 +3,132 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
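Taken together, the build_generator and SequenceGenerator changes above route Espresso-specific options (lm_weight, eos_factor) through an extra_gen_cls_kwargs dict on the task side and **kwargs on the generator side, so the generator's named signature stays aligned with upstream fairseq. A stripped-down sketch of the pattern; ToyGenerator and the hard-coded values are stand-ins for illustration, not the real classes.

class ToyGenerator:
    def __init__(self, models, tgt_dict, beam_size=5, **kwargs):
        # extras arrive via **kwargs with safe defaults, mirroring the diff above
        self.lm_weight = kwargs.get("lm_weight", 0.0)
        self.eos_factor = kwargs.get("eos_factor", None)
        assert self.eos_factor is None or self.eos_factor >= 1.0

def build_generator(models, tgt_dict, seq_gen_cls=None, extra_gen_cls_kwargs=None):
    if seq_gen_cls is None:
        seq_gen_cls = ToyGenerator
    extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
    extra_gen_cls_kwargs["lm_weight"] = 0.1  # would come from args.lm_weight
    return seq_gen_cls(models, tgt_dict, beam_size=8, **extra_gen_cls_kwargs)

gen = build_generator(models=[], tgt_dict=None)
assert gen.lm_weight == 0.1 and gen.eos_factor is None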
+import logging import math +import torch + from fairseq import utils from fairseq.criterions import FairseqCriterion, register_criterion from fairseq.logging import metrics +logger = logging.getLogger(__name__) + + +class ChainLossFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input, input_lengths, num_graphs, den_graphs, leaky_coefficient=1e-5): + try: + import pychain_C + except ImportError: + raise ImportError( + "Please install OpenFST and PyChain by `make openfst pychain` " + "after entering espresso/tools" + ) + + input = input.clamp(-30, 30) # clamp for both the denominator and the numerator + B = input.size(0) + if B != num_graphs.batch_size or B != den_graphs.batch_size: + raise ValueError( + "input batch size ({}) does not equal to num graph batch size ({}) " + "or den graph batch size ({})" + .format(B, num_graphs.batch_size, den_graphs.batch_size) + ) + packed_data = torch.nn.utils.rnn.pack_padded_sequence( + input, input_lengths, batch_first=True, + ) + batch_sizes = packed_data.batch_sizes + input_lengths = input_lengths.cpu() + + exp_input = input.exp() + den_objf, input_grad, denominator_ok = pychain_C.forward_backward( + den_graphs.forward_transitions, + den_graphs.forward_transition_indices, + den_graphs.forward_transition_probs, + den_graphs.backward_transitions, + den_graphs.backward_transition_indices, + den_graphs.backward_transition_probs, + den_graphs.leaky_probs, + den_graphs.initial_probs, + den_graphs.final_probs, + den_graphs.start_state, + exp_input, + batch_sizes, + input_lengths, + den_graphs.num_states, + leaky_coefficient, + ) + denominator_ok = denominator_ok.item() + + assert num_graphs.log_domain + num_objf, log_probs_grad, numerator_ok = pychain_C.forward_backward_log_domain( + num_graphs.forward_transitions, + num_graphs.forward_transition_indices, + num_graphs.forward_transition_probs, + num_graphs.backward_transitions, + num_graphs.backward_transition_indices, + num_graphs.backward_transition_probs, + num_graphs.initial_probs, + num_graphs.final_probs, + num_graphs.start_state, + input, + batch_sizes, + input_lengths, + num_graphs.num_states, + ) + numerator_ok = numerator_ok.item() + + loss = -num_objf + den_objf + + if (loss - loss) != 0.0 or not denominator_ok or not numerator_ok: + default_loss = 10 + input_grad = torch.zeros_like(input) + logger.warning( + f"Loss is {loss} and denominator computation " + f"(if done) returned {denominator_ok} " + f"and numerator computation returned {numerator_ok} " + f", setting loss to {default_loss} per frame" + ) + loss = torch.full_like(num_objf, default_loss * input_lengths.sum()) + else: + num_grad = log_probs_grad.exp() + input_grad -= num_grad + + ctx.save_for_backward(input_grad) + return loss + + @staticmethod + def backward(ctx, objf_grad): + input_grad, = ctx.saved_tensors + input_grad = torch.mul(input_grad, objf_grad) + + return input_grad, None, None, None, None + + @register_criterion("lattice_free_mmi") class LatticeFreeMMICriterion(FairseqCriterion): def __init__( self, task, sentence_avg, denominator_fst_path, - den_leaky_hmm_coefficient, num_leaky_hmm_coefficient, + leaky_hmm_coefficient, xent_regularize, output_l2_regularize, ): super().__init__(task) try: from pychain.graph import ChainGraph import simplefst except ImportError: - raise ImportError("Please install OpenFST and PyChain by `make openfst pychain` after entering espresso/tools") + raise ImportError( + "Please install OpenFST and PyChain by `make openfst pychain` " + "after entering espresso/tools" + ) 
self.sentence_avg = sentence_avg den_fst = simplefst.StdVectorFst.read(denominator_fst_path) - self.den_graph = ChainGraph(den_fst, leaky_mode="transition") - self.den_leaky_hmm_coefficient = den_leaky_hmm_coefficient - self.num_leaky_hmm_coefficient = num_leaky_hmm_coefficient + self.den_graph = ChainGraph(den_fst, initial_mode="leaky", final_mode="ones") + self.leaky_hmm_coefficient = leaky_hmm_coefficient + self.xent_regularize = xent_regularize + self.output_l2_regularize = output_l2_regularize @staticmethod def add_args(parser): @@ -37,10 +137,14 @@ def add_args(parser): FairseqCriterion.add_args(parser) parser.add_argument("--denominator-fst-path", type=str, metavar="FILE", help="path to the denominator fst file") - parser.add_argument("--den-leaky-hmm-coefficient", default=1.0e-05, type=float, metavar="F", + parser.add_argument("--leaky-hmm-coefficient", default=1.0e-05, type=float, metavar="F", help="leaky-hmm coefficient for the denominator") - parser.add_argument("--num-leaky-hmm-coefficient", default=1.0e-15, type=float, metavar="F", - help="leaky-hmm coefficient for the numerator") + parser.add_argument("--xent-regularization-coefficient", default=0.0, + type=float, metavar="F", dest="xent_regularize", + help="cross-entropy regularization coefficient") + parser.add_argument("--output-l2-regularization-coefficient", default=0.0, + type=float, metavar="F", dest="output_l2_regularize", + help="L2 regularization coefficient for the network's output") # fmt: on def forward(self, model, sample, reduce=True): @@ -52,11 +156,12 @@ def forward(self, model, sample, reduce=True): 3) logging outputs to display while training """ net_output = model(**sample["net_input"]) - loss, _ = self.compute_loss(net_output, sample, reduce=reduce) + loss, nll_loss = self.compute_loss(net_output, sample, reduce=reduce) sample_size = sample["target"].batch_size if self.sentence_avg else sample["ntokens"] logging_output = { "loss": loss.data, + "nll_loss": nll_loss.data, "ntokens": sample["ntokens"], "nsentences": sample["nsentences"], "sample_size": sample_size, @@ -70,27 +175,45 @@ def compute_loss(self, net_output, sample, reduce=True): except ImportError: raise ImportError("Please install OpenFST and PyChain by `make openfst pychain` after entering espresso/tools") - den_graphs = ChainGraphBatch(self.den_graph, sample["nsentences"]) encoder_out = net_output.encoder_out.transpose(0, 1) # T x B x V -> B x T x V out_lengths = net_output.src_lengths.long() # B - den_objf = ChainFunction.apply(encoder_out, out_lengths, den_graphs, self.den_leaky_hmm_coefficient) - num_objf = ChainFunction.apply(encoder_out, out_lengths, sample["target"], self.num_leaky_hmm_coefficient) - loss = - num_objf + den_objf # negative log-probs - return loss, loss + den_graphs = ChainGraphBatch(self.den_graph, sample["nsentences"]) + if self.xent_regularize > 0.0: + den_objf = ChainFunction.apply(encoder_out, out_lengths, den_graphs, self.leaky_hmm_coefficient) + num_objf = ChainFunction.apply(encoder_out, out_lengths, sample["target"]) + loss = - num_objf + den_objf # negative log-probs + nll_loss = loss.clone().detach() + loss -= self.xent_regularize * num_objf + else: + # demonstrate another more "integrated" usage of the PyChain loss. it's equivalent to + # the first three lines in the above "if" block, but also supports throwing away + # batches with the NaN loss by setting their gradients to 0. 
+ loss = ChainLossFunction.apply( + encoder_out, out_lengths, sample["target"], den_graphs, self.leaky_hmm_coefficient + ) + nll_loss = loss.clone().detach() + + if self.output_l2_regularize > 0.0: + encoder_padding_mask = net_output.encoder_padding_mask + encoder_out_squared = encoder_out.pow(2.0) + if encoder_padding_mask is not None: + pad_mask = encoder_padding_mask.transpose(0, 1).unsqueeze(-1) # T x B -> B x T x 1 + encoder_out_squared.masked_fill_(pad_mask, 0.0) + loss += 0.5 * self.output_l2_regularize * encoder_out_squared.sum() + + return loss, nll_loss @staticmethod def reduce_metrics(logging_outputs) -> None: """Aggregate logging outputs from data parallel training.""" loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + nll_loss_sum = sum(log.get('nll_loss', 0) for log in logging_outputs) ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) metrics.log_scalar("loss", loss_sum / sample_size / math.log(2), sample_size, round=7) - if sample_size != ntokens: - metrics.log_scalar("nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=7) - metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg, round=4)) - else: - metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["loss"].avg, round=4)) + metrics.log_scalar("nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=7) + metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg, round=4)) @staticmethod def logging_outputs_can_be_summed() -> bool: diff --git a/espresso/data/asr_chain_dataset.py b/espresso/data/asr_chain_dataset.py index c819e0453..0085aba20 100644 --- a/espresso/data/asr_chain_dataset.py +++ b/espresso/data/asr_chain_dataset.py @@ -110,7 +110,7 @@ def read_fsts(self, utt_ids: List[str], rxfiles: List[str]): for i, rxfile in enumerate(rxfiles): file_path, offset = self._parse_rxfile(rxfile) fst = simplefst.StdVectorFst.read_ark(file_path, offset) - graph = ChainGraph(fst, leaky_mode="uniform") + graph = ChainGraph(fst, initial_mode="fst", final_mode="fst", log_domain=True) if not graph.is_empty: # skip empty graphs self.utt_ids.append(utt_ids[i]) self.rxfiles.append(rxfile) diff --git a/examples/asr_wsj/local/data_prep_char.sh b/examples/asr_wsj/local/data_prep_char.sh new file mode 100755 index 000000000..c35487b91 --- /dev/null +++ b/examples/asr_wsj/local/data_prep_char.sh @@ -0,0 +1,98 @@ +#!/bin/bash +# Copyright (c) Yiwen Shao, Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# The common data preparation script for hybrid systems + +set -euo pipefail + +stage=-10 +nj=30 +train_set=train_si284 +test_set="test_dev93 test_eval92" + +wsj0= +wsj1= +if [[ $(hostname -f) == *.clsp.jhu.edu ]]; then + wsj0=/export/corpora5/LDC/LDC93S6B + wsj1=/export/corpora5/LDC/LDC94S13B +fi + +. ./path.sh +. ./cmd.sh +. ./utils/parse_options.sh + + +if [ $stage -le -4 ]; then + # data preparation + [[ -d data/local/data ]] || \ + local/wsj_data_prep.sh $wsj0/??-{?,??}.? $wsj1/??-{?,??}.? + [[ -f data/local/dict_nosp/lexicon.txt ]] || \ + local/wsj_prepare_dict.sh --dict-suffix "_nosp" + + local/wsj_prepare_char_dict.sh + utils/prepare_lang.sh data/local/dict_char \ + "" data/local/lang_tmp_char data/lang_char + local/wsj_format_data.sh --lang-suffix "_char" + echo "Done formatting the data & lang." 
+ + local/wsj_extend_char_dict.sh $wsj1/13-32.1 data/local/dict_char \ + data/local/dict_char_larger + utils/prepare_lang.sh data/local/dict_char_larger \ + "" data/local/lang_larger_tmp \ + data/lang_char_bd + local/wsj_train_lms.sh --dict-suffix "_char" + local/wsj_format_local_lms.sh --lang-suffix "_char" + echo "Done exteding the vocabulary." +fi + +if [ $stage -le -3 ]; then + # make MFCC features for the test data + if [ -f data/test_eval92_hires/feats.scp ]; then + echo "$0: It seems that features for the test sets already exist." + echo "skipping this stage..." + else + echo "$0: extracting MFCC features for the test sets" + for dataset in $test_set; do + mv data/$dataset data/${dataset}_hires + steps/make_mfcc.sh --cmd "$train_cmd" --nj $nj \ + --mfcc-config conf/mfcc_hires.conf data/${dataset}_hires + steps/compute_cmvn_stats.sh data/${dataset}_hires + done + fi +fi + +if [ $stage -le -2 ]; then + if [ -f data/${train_set}_sp_hires/feats.scp ]; then + echo "$0: It seems that features for the perturbed training data already exist." + echo "If you want to extract them anyway, remove them first and run this" + echo "stage again. Skipping this stage..." + else + echo "$0: perturbing the training data" + utils/data/get_utt2dur.sh data/$train_set + utils/data/perturb_data_dir_speed_3way.sh data/${train_set} data/${train_set}_sp + utils/copy_data_dir.sh data/${train_set}_sp data/${train_set}_sp_hires + + # do volume-perturbation on the training data prior to extracting hires + # features; this helps make trained nnets more invariant to test data volume. + utils/data/perturb_data_dir_volume.sh data/${train_set}_sp_hires + fi +fi + +if [ $stage -le -1 ]; then + if [ -f data/${train_set}_sp_hires/feats.scp ]; then + echo "$0: It seems that features for the perturbed training data already exist." + echo "If you want to extract them anyway, remove them first and run this" + echo "stage again. Skipping this stage..." 
+ else + echo "$0: extracting MFCC features for the training data" + steps/make_mfcc.sh --cmd "$train_cmd" --nj $nj \ + --mfcc-config conf/mfcc_hires.conf data/${train_set}_sp_hires + steps/compute_cmvn_stats.sh data/${train_set}_sp_hires + utils/fix_data_dir.sh data/${train_set}_sp_hires + fi +fi + +exit 0; diff --git a/examples/asr_wsj/local/wsj_extend_char_dict.sh b/examples/asr_wsj/local/wsj_extend_char_dict.sh new file mode 120000 index 000000000..c43a6df30 --- /dev/null +++ b/examples/asr_wsj/local/wsj_extend_char_dict.sh @@ -0,0 +1 @@ +../../../espresso/tools/kaldi/egs/wsj/s5/local/wsj_extend_char_dict.sh \ No newline at end of file diff --git a/examples/asr_wsj/local/wsj_format_data.sh b/examples/asr_wsj/local/wsj_format_data.sh new file mode 120000 index 000000000..12930dcbf --- /dev/null +++ b/examples/asr_wsj/local/wsj_format_data.sh @@ -0,0 +1 @@ +../../../espresso/tools/kaldi/egs/wsj/s5/local/wsj_format_data.sh \ No newline at end of file diff --git a/examples/asr_wsj/local/wsj_format_local_lms.sh b/examples/asr_wsj/local/wsj_format_local_lms.sh new file mode 120000 index 000000000..eb5126eb4 --- /dev/null +++ b/examples/asr_wsj/local/wsj_format_local_lms.sh @@ -0,0 +1 @@ +../../../espresso/tools/kaldi/egs/wsj/s5/local/wsj_format_local_lms.sh \ No newline at end of file diff --git a/examples/asr_wsj/local/wsj_prepare_char_dict.sh b/examples/asr_wsj/local/wsj_prepare_char_dict.sh new file mode 120000 index 000000000..47c36f3eb --- /dev/null +++ b/examples/asr_wsj/local/wsj_prepare_char_dict.sh @@ -0,0 +1 @@ +../../../espresso/tools/kaldi/egs/wsj/s5/local/wsj_prepare_char_dict.sh \ No newline at end of file diff --git a/examples/asr_wsj/local/wsj_prepare_dict.sh b/examples/asr_wsj/local/wsj_prepare_dict.sh new file mode 120000 index 000000000..f5d47178b --- /dev/null +++ b/examples/asr_wsj/local/wsj_prepare_dict.sh @@ -0,0 +1 @@ +../../../espresso/tools/kaldi/egs/wsj/s5/local/wsj_prepare_dict.sh \ No newline at end of file diff --git a/examples/asr_wsj/local/wsj_train_lms.sh b/examples/asr_wsj/local/wsj_train_lms.sh new file mode 120000 index 000000000..3350c8a68 --- /dev/null +++ b/examples/asr_wsj/local/wsj_train_lms.sh @@ -0,0 +1 @@ +../../../espresso/tools/kaldi/egs/wsj/s5/local/wsj_train_lms.sh \ No newline at end of file diff --git a/examples/asr_wsj/run_chain_e2e.sh b/examples/asr_wsj/run_chain_e2e.sh index b1aa898ca..51a550ff7 100755 --- a/examples/asr_wsj/run_chain_e2e.sh +++ b/examples/asr_wsj/run_chain_e2e.sh @@ -190,14 +190,13 @@ if [ ${stage} -le 6 ]; then --num-workers 0 --data-buffer-size 0 --max-tokens 120000 --max-sentences 128 --curriculum 1 --empty-cache-freq 50 \ --valid-subset $valid_subset --max-sentences-valid 128 --ddp-backend no_c10d --update-freq $update_freq \ --distributed-world-size $ngpus \ - --max-epoch 26 --optimizer adam --lr 0.001 --weight-decay 0.0 --start-reduce-lr-epoch 11 \ + --max-epoch 30 --optimizer adam --lr 0.001 --weight-decay 0.0 --start-reduce-lr-epoch 11 \ --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 \ --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates $((400/ngpus/update_freq)) \ --keep-interval-updates 5 --keep-last-epochs 5 --validate-interval 1 \ --arch speech_tdnn_wsj --criterion lattice_free_mmi --num-targets $num_targets \ - --dropout 0.2 --kernel-sizes "[3]*6" --strides "[1]*5+[3]" --dilations "[1,1,1,3,3,3]" --num-layers 6 --residual True \ - --denominator-fst-path $tree_dir/normalization.fst \ - --den-leaky-hmm-coefficient 1e-03 --num-leaky-hmm-coefficient 1e-20 \ + --dropout 0.2 
--kernel-sizes "[3]*6" --strides "[1]*5+[3]" --dilations "[1,1,1,3,3,3]" --num-layers 6 \ + --denominator-fst-path $tree_dir/den.fst --leaky-hmm-coefficient 1e-03 \ --max-source-positions 9999 --max-target-positions 9999 2>&1 | tee $log_file fi diff --git a/examples/asr_wsj/run_chain_e2e_bichar.sh b/examples/asr_wsj/run_chain_e2e_bichar.sh new file mode 100755 index 000000000..4192a7103 --- /dev/null +++ b/examples/asr_wsj/run_chain_e2e_bichar.sh @@ -0,0 +1,240 @@ +#!/bin/bash +# Copyright (c) Yiming Wang, Yiwen Shao +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +set -e -o pipefail + +stage=-10 +ngpus=1 # num GPUs for multiple GPUs training within a single node; should match those in $free_gpu +free_gpu= # comma-separated available GPU ids, eg., "0" or "0,1"; automatically assigned if on CLSP grid + +# model and data related +affix= +lang=data/lang_chain_e2e_char +tree_dir=exp/chain/e2e_bichar_tree # it's actually just a trivial tree (no tree building) +whole_train_set=train_si284_sp # will be split into train_set and valid_set +train_set=train_si284_novalid_spe2e +valid_set=train_si284_valid_spe2e +test_set="test_dev93 test_eval92" +dumpdir=data/dump # directory to dump full features +checkpoint=checkpoint_best.pt + +wsj0= +wsj1= +if [[ $(hostname -f) == *.clsp.jhu.edu ]]; then + wsj0=/export/corpora5/LDC/LDC93S6B + wsj1=/export/corpora5/LDC/LDC94S13B +fi + +. ./path.sh +. ./cmd.sh +. ./utils/parse_options.sh + +dir=exp/tdnn_chain_e2e_bichar${affix:+_$affix} + +local/data_prep_char.sh --stage $stage --wsj0 $wsj0 --wsj1 $wsj1 || exit 1; + +if [ $stage -le 0 ]; then + echo "Stage 0: Create the $lang Directory that Has a Specific HMM Topolopy" + rm -rf $lang + cp -r data/lang_char $lang + silphonelist=$(cat $lang/phones/silence.csl) || exit 1; + nonsilphonelist=$(cat $lang/phones/nonsilence.csl) || exit 1; + steps/nnet3/chain/gen_topo.py $nonsilphonelist $silphonelist >$lang/topo +fi + +if [ $stage -le 1 ]; then + echo "Stage 1: Generate Denominator Graph and Numerator Fsts" + echo "$0: Estimating a phone language model for the denominator graph..." + mkdir -p $tree_dir/log + $train_cmd $tree_dir/log/make_phone_lm.log \ + cat data/${whole_train_set}_hires/text \| \ + steps/nnet3/chain/e2e/text_to_phones.py --between-silprob 0.1 \ + data/lang_char \| \ + utils/sym2int.pl -f 2- data/lang_char/phones.txt \| \ + chain-est-phone-lm --num-extra-lm-states=2000 \ + ark:- $tree_dir/phone_lm.fst + nj=32 + steps/nnet3/chain/e2e/prepare_e2e.sh --nj $nj --cmd "$train_cmd" \ + --type biphone --shared-phones true --tie true data/${whole_train_set}_hires $lang $tree_dir + echo "$0: Making denominator fst..." + $decode_cmd $tree_dir/log/make_den_fst.log \ + chain-make-den-fst $tree_dir/tree $tree_dir/0.trans_mdl $tree_dir/phone_lm.fst \ + $tree_dir/den.fst $tree_dir/normalization.fst || exit 1 + echo "$0: Making numerator fsts..." + abs_treedir=`utils/make_absolute.sh $tree_dir` + $decode_cmd JOB=1:$nj $tree_dir/log/make_num_fst_e2e.JOB.log \ + chain-make-num-fst-e2e $tree_dir/0.trans_mdl $tree_dir/normalization.fst \ + scp:$tree_dir/fst.JOB.scp ark,scp:$abs_treedir/fst_nor.JOB.ark,$abs_treedir/fst_nor.JOB.scp || exit 1 + for n in $(seq $nj); do + cat $tree_dir/fst_nor.$n.scp || exit 1 + done > $tree_dir/fst_nor.scp || exit 1 +fi + +if [ ${stage} -le 2 ]; then + echo "Stage 2: Split the Whole Train Set into Train/Valid Set" + # Get list of validation utterances. 
+ data=data/${whole_train_set}_hires + set +e + awk '{print $1}' $data/utt2spk | utils/shuffle_list.pl 2>/dev/null | head -300 > valid_uttlist + set -e + if [ -f $data/utt2uniq ]; then # this matters if you use data augmentation. + echo "File $data/utt2uniq exists, so augmenting valid_uttlist to" + echo "include all perturbed versions of the same 'real' utterances." + mv valid_uttlist valid_uttlist.tmp + utils/utt2spk_to_spk2utt.pl $data/utt2uniq > uniq2utt + cat valid_uttlist.tmp | utils/apply_map.pl $data/utt2uniq | \ + sort | uniq | utils/apply_map.pl uniq2utt | \ + awk '{for(n=1;n<=NF;n++) print $n;}' | sort > valid_uttlist + rm uniq2utt valid_uttlist.tmp 2>/dev/null + fi + # generate train/valid data dir + utils/filter_scp.pl --exclude valid_uttlist $data/utt2spk | cut -d" " -f1 > novalid_uttlist || exit 1 + utils/subset_data_dir.sh --utt-list novalid_uttlist $data data/${train_set}_hires || exit 1 + utils/subset_data_dir.sh --utt-list valid_uttlist $data data/${valid_set}_hires || exit 1 + + # generate train/valid numerator fst file + utils/filter_scp.pl novalid_uttlist $tree_dir/fst_nor.scp > $tree_dir/fst_novalid_nor.scp || exit 1 + utils/filter_scp.pl valid_uttlist $tree_dir/fst_nor.scp > $tree_dir/fst_valid_nor.scp || exit 1 + rm valid_uttlist novalid_uttlist 2>/dev/null + + # not all fsts can be generated successfully, just filter out those not having the fst + for dataset in $train_set $valid_set; do + tag=novalid && [[ "$dataset" == "$valid_set" ]] && tag=valid + cp data/${dataset}_hires/feats.scp data/${dataset}_hires/feats.scp.tmp + utils/filter_scp.pl $tree_dir/fst_${tag}_nor.scp data/${dataset}_hires/feats.scp.tmp \ + > data/${dataset}_hires/feats.scp || exit 1 + rm data/${dataset}_hires/feats.scp.tmp 2>/dev/null + utils/fix_data_dir.sh data/${dataset}_hires || exit 1 + done +fi + +if [ ${stage} -le 3 ]; then + echo "Stage 3: Dump Feature" + for dataset in $train_set $valid_set $test_set; do + nj=8 + utils/split_data.sh data/${dataset}_hires $nj + sdata=data/${dataset}_hires/split$nj + mkdir -p $dumpdir/${dataset}_hires; abs_featdir=`utils/make_absolute.sh $dumpdir/${dataset}_hires` + $train_cmd JOB=1:$nj $abs_featdir/log/dump_feature.JOB.log \ + apply-cmvn --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp \ + scp:$sdata/JOB/feats.scp ark:- \| \ + copy-feats --compress=true --compression-method=2 ark:- \ + ark,scp:$abs_featdir/feats.JOB.ark,$abs_featdir/feats.JOB.scp || exit 1 + for n in $(seq $nj); do + cat $abs_featdir/feats.$n.scp || exit 1 + done > $abs_featdir/feats.scp || exit 1 + rm $abs_featdir/feats.*.scp 2>/dev/null + cat data/${dataset}_hires/utt2num_frames > $abs_featdir/utt2num_frames || exit 1 + cat data/${dataset}_hires/utt2spk > $abs_featdir/utt2spk || exit 1 + done +fi + +if [ ${stage} -le 4 ]; then + echo "Stage 4: Make Graphs" + for lmtype in tgpr bd_tgpr; do + utils/lang/check_phones_compatible.sh \ + data/lang_char_test_$lmtype/phones.txt $lang/phones.txt + utils/mkgraph.sh --self-loop-scale 1.0 data/lang_char_test_$lmtype $tree_dir $tree_dir/graph_$lmtype || exit 1 + done +fi + +if [ ${stage} -le 5 ]; then + echo "Stage 5: Dump Json Files" + train_feat=$dumpdir/${train_set}_hires/feats.scp + train_fst=${tree_dir}/fst_novalid_nor.scp + train_text=data/${train_set}_hires/text + train_utt2num_frames=data/${train_set}_hires/utt2num_frames + valid_feat=$dumpdir/${valid_set}_hires/feats.scp + valid_fst=${tree_dir}/fst_valid_nor.scp + valid_text=data/${valid_set}_hires/text + valid_utt2num_frames=data/${valid_set}_hires/utt2num_frames + mkdir -p 
data/chain_e2e_bichar + asr_prep_json.py --feat-files $train_feat --numerator-fst-files $train_fst --text-files $train_text \ + --utt2num-frames-files $train_utt2num_frames --output data/chain_e2e_bichar/train.json + asr_prep_json.py --feat-files $valid_feat --numerator-fst-files $valid_fst --text-files $valid_text \ + --utt2num-frames-files $valid_utt2num_frames --output data/chain_e2e_bichar/valid.json + for dataset in $test_set; do + nj=$(wc -l &1 | tee $log_file +fi + +if [ ${stage} -le 7 ]; then + echo "Stage 7: Decoding" + rm $dir/.error 2>/dev/null || true + queue_opt="--num-threads 4" + path=$dir/$checkpoint + for dataset in $test_set; do + ( + data_affix=$(echo $dataset | sed s/test_//) + nj=$(wc -l $dir/decode_${lmtype}_${data_affix}/lat.JOB.gz" || exit 1 + local/score.sh --cmd "$decode_cmd" data/${dataset}_hires $graph_dir $dir/decode_${lmtype}_${data_affix} || exit 1 + echo $nj > $dir/decode_${lmtype}_${data_affix}/num_jobs + done + steps/lmrescore.sh --cmd "$decode_cmd" --self-loop-scale 1.0 --mode 3 data/lang_char_test_{tgpr,tg} \ + data/${dataset}_hires $dir/decode_{tgpr,tg}_${data_affix} || exit 1 + steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" data/lang_char_test_bd_{tgpr,fgconst} \ + data/${dataset}_hires $dir/decode_bd_tgpr_${data_affix}{,_fg} || exit 1 + ) || touch $dir/.error & + done + wait + [ -f $dir/.error ] && echo "$0: there was a problem while decoding" && exit 1 + for dataset in $test_set; do + data_affix=$(echo $dataset | sed s/test_//) + for x in $dir/decode_{tgpr_${data_affix},tg_${data_affix},bd_tgpr_${data_affix},bd_tgpr_${data_affix}_fg}; do + grep WER $x/wer_* | utils/best_wer.sh + done + done +fi From afaef921077c44b31a1a0dc9505a04ea4766a0fe Mon Sep 17 00:00:00 2001 From: freewym Date: Mon, 10 Aug 2020 15:07:42 -0400 Subject: [PATCH 095/119] code adaptation/changes according to the commits on Aug 10-18, 2020 --- espresso/data/asr_chain_dataset.py | 32 ++++++++++++++ espresso/data/asr_dataset.py | 32 ++++++++++++++ espresso/data/asr_xent_dataset.py | 32 ++++++++++++++ espresso/speech_train.py | 49 +++++++-------------- espresso/tasks/speech_recognition.py | 5 ++- espresso/tasks/speech_recognition_hybrid.py | 5 ++- 6 files changed, 119 insertions(+), 36 deletions(-) diff --git a/espresso/data/asr_chain_dataset.py b/espresso/data/asr_chain_dataset.py index 0085aba20..163c4fb82 100644 --- a/espresso/data/asr_chain_dataset.py +++ b/espresso/data/asr_chain_dataset.py @@ -350,6 +350,38 @@ def prefetch(self, indices): """Only prefetch src.""" self.src.prefetch(indices) + def filter_indices_by_size(self, indices, max_sizes): + """ Filter a list of sample indices. Remove those that are longer + than specified in max_sizes. 
+ + Args: + indices (np.array): original array of sample indices + max_sizes (int or list[int] or tuple[int]): max sample size, + can be defined separately for src and tgt (then list or tuple) + + Returns: + np.array: filtered sample array + list: list of removed indices + """ + if max_sizes is None: + return indices, [] + if type(max_sizes) in (int, float): + max_src_size, max_tgt_size = max_sizes, max_sizes + else: + max_src_size, max_tgt_size = max_sizes + if self.tgt_sizes is None: + ignored = indices[self.src_sizes[indices] > max_src_size] + else: + ignored = indices[(self.src_sizes[indices] > max_src_size) | + (self.tgt_sizes[indices] > max_tgt_size)] + if len(ignored) > 0: + if self.tgt_sizes is None: + indices = indices[self.src_sizes[indices] <= max_src_size] + else: + indices = indices[(self.src_sizes[indices] <= max_src_size) & + (self.tgt_sizes[indices] <= max_tgt_size)] + return indices, ignored.tolist() + def set_epoch(self, epoch): super().set_epoch(epoch) self.epoch = epoch diff --git a/espresso/data/asr_dataset.py b/espresso/data/asr_dataset.py index 59eb818bc..0e5ab3c78 100644 --- a/espresso/data/asr_dataset.py +++ b/espresso/data/asr_dataset.py @@ -331,6 +331,38 @@ def prefetch(self, indices): """Only prefetch src.""" self.src.prefetch(indices) + def filter_indices_by_size(self, indices, max_sizes): + """ Filter a list of sample indices. Remove those that are longer + than specified in max_sizes. + + Args: + indices (np.array): original array of sample indices + max_sizes (int or list[int] or tuple[int]): max sample size, + can be defined separately for src and tgt (then list or tuple) + + Returns: + np.array: filtered sample array + list: list of removed indices + """ + if max_sizes is None: + return indices, [] + if type(max_sizes) in (int, float): + max_src_size, max_tgt_size = max_sizes, max_sizes + else: + max_src_size, max_tgt_size = max_sizes + if self.tgt_sizes is None: + ignored = indices[self.src_sizes[indices] > max_src_size] + else: + ignored = indices[(self.src_sizes[indices] > max_src_size) | + (self.tgt_sizes[indices] > max_tgt_size)] + if len(ignored) > 0: + if self.tgt_sizes is None: + indices = indices[self.src_sizes[indices] <= max_src_size] + else: + indices = indices[(self.src_sizes[indices] <= max_src_size) & + (self.tgt_sizes[indices] <= max_tgt_size)] + return indices, ignored.tolist() + def set_epoch(self, epoch): super().set_epoch(epoch) if hasattr(self.src, 'set_epoch'): diff --git a/espresso/data/asr_xent_dataset.py b/espresso/data/asr_xent_dataset.py index b96c60f72..e012620f9 100644 --- a/espresso/data/asr_xent_dataset.py +++ b/espresso/data/asr_xent_dataset.py @@ -583,6 +583,38 @@ def prefetch(self, indices): if self.tgt is not None: self.tgt.prefetch(indices) + def filter_indices_by_size(self, indices, max_sizes): + """ Filter a list of sample indices. Remove those that are longer + than specified in max_sizes. 
+ + Args: + indices (np.array): original array of sample indices + max_sizes (int or list[int] or tuple[int]): max sample size, + can be defined separately for src and tgt (then list or tuple) + + Returns: + np.array: filtered sample array + list: list of removed indices + """ + if max_sizes is None: + return indices, [] + if type(max_sizes) in (int, float): + max_src_size, max_tgt_size = max_sizes, max_sizes + else: + max_src_size, max_tgt_size = max_sizes + if self.tgt_sizes is None: + ignored = indices[self.src_sizes[indices] > max_src_size] + else: + ignored = indices[(self.src_sizes[indices] > max_src_size) | + (self.tgt_sizes[indices] > max_tgt_size)] + if len(ignored) > 0: + if self.tgt_sizes is None: + indices = indices[self.src_sizes[indices] <= max_src_size] + else: + indices = indices[(self.src_sizes[indices] <= max_src_size) & + (self.tgt_sizes[indices] <= max_tgt_size)] + return indices, ignored.tolist() + def set_epoch(self, epoch): super().set_epoch(epoch) self.epoch = epoch diff --git a/espresso/speech_train.py b/espresso/speech_train.py index 84cc4a490..564a5bdd4 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -43,6 +43,7 @@ def main(args): assert ( args.max_tokens is not None or args.max_sentences is not None ), "Must specify batch size either with --max-tokens or --max-sentences" + metrics.reset() np.random.seed(args.seed) @@ -65,8 +66,10 @@ def main(args): model = task.build_model(args) criterion = task.build_criterion(args) logger.info(model) + logger.info("task: {} ({})".format(args.task, task.__class__.__name__)) + logger.info("model: {} ({})".format(args.arch, model.__class__.__name__)) logger.info( - "model {}, criterion {}".format(args.arch, criterion.__class__.__name__) + "criterion: {} ({})".format(args.criterion, criterion.__class__.__name__) ) logger.info( "num. model params: {} (num. trained: {})".format( @@ -103,11 +106,6 @@ def main(args): # Load the latest checkpoint if one is available and restore the # corresponding train iterator extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer) - if args.tpu: - import torch_xla.core.xla_model as xm - - xm.rendezvous("load_checkpoint") # wait for all workers - xm.mark_step() # Train until the learning rate gets too small max_epoch = args.max_epoch or math.inf @@ -161,25 +159,9 @@ def is_better(a, b): return False -def tpu_data_loader(args, itr): - import torch_xla.core.xla_model as xm - import torch_xla.distributed.parallel_loader as pl - - xm.rendezvous("tpu_data_loader") # wait for all workers - xm.mark_step() - device = utils.get_tpu_device(args) - return iterators.CountingIterator( - pl.ParallelLoader(itr, [device]).per_device_loader(device), - start=getattr(itr, "n", 0), - total=len(itr), - ) - - @metrics.aggregate("train") def train(args, trainer, task, epoch_itr): """Train the model for one epoch and return validation losses.""" - logger.info("begin training epoch {}".format(epoch_itr.epoch)) - # Initialize data iterator itr = epoch_itr.next_epoch_itr( fix_batches_to_gpus=args.fix_batches_to_gpus, @@ -192,7 +174,7 @@ def train(args, trainer, task, epoch_itr): ) itr = iterators.GroupedIterator(itr, update_freq) if getattr(args, "tpu", False): - itr = tpu_data_loader(args, itr) + itr = utils.tpu_data_loader(itr) progress = progress_bar.progress_bar( itr, log_format=args.log_format, @@ -217,18 +199,17 @@ def train(args, trainer, task, epoch_itr): "train_step-%d" % i ): log_output = trainer.train_step(samples) - if log_output is None: # OOM, overflow, ... 
- continue - # log mid-epoch stats - num_updates = trainer.get_num_updates() - if num_updates % args.log_interval == 0: - stats = get_training_stats(metrics.get_smoothed_values("train_inner")) - progress.log(stats, tag="train_inner", step=num_updates) + if log_output is not None: # not OOM, overflow, ... + # log mid-epoch stats + num_updates = trainer.get_num_updates() + if num_updates % args.log_interval == 0: + stats = get_training_stats(metrics.get_smoothed_values("train_inner")) + progress.log(stats, tag="train_inner", step=num_updates) - # reset mid-epoch stats after each log interval - # the end-of-epoch stats will still be preserved - metrics.reset_meters("train_inner") + # reset mid-epoch stats after each log interval + # the end-of-epoch stats will still be preserved + metrics.reset_meters("train_inner") # update the state prior stored in the model for cross-entropy training if hasattr(task, "update_state_prior"): @@ -309,7 +290,7 @@ def validate(args, trainer, task, epoch_itr, subsets): # Initialize data iterator itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False) if getattr(args, "tpu", False): - itr = tpu_data_loader(args, itr) + itr = utils.tpu_data_loader(itr) progress = progress_bar.progress_bar( itr, log_format=args.log_format, diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index a5d031d0f..9a72385b8 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -234,6 +234,9 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): """ paths = utils.split_paths(self.args.data) assert len(paths) > 0 + if split != getattr(self.args, "train_subset", None): + # if not training data set, use the first shard for valid and test + paths = paths[:1] data_path = paths[(epoch - 1) % len(paths)] self.datasets[split] = get_asr_dataset_from_json( @@ -241,7 +244,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): combine=combine, upsample_primary=self.args.upsample_primary, num_buckets=self.args.num_batch_buckets, - shuffle=(split != getattr(self.args, "gen_subset", "")), + shuffle=(split != getattr(self.args, "gen_subset", None)), seed=self.args.seed, specaugment_config=self.specaugment_config, ) diff --git a/espresso/tasks/speech_recognition_hybrid.py b/espresso/tasks/speech_recognition_hybrid.py index ad0f1feb9..b52b70037 100644 --- a/espresso/tasks/speech_recognition_hybrid.py +++ b/espresso/tasks/speech_recognition_hybrid.py @@ -317,6 +317,9 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): """ paths = utils.split_paths(self.args.data) assert len(paths) > 0 + if split != getattr(self.args, "train_subset", None): + # if not training data set, use the first shard for valid and test + paths = paths[:1] data_path = paths[(epoch - 1) % len(paths)] self.datasets[split] = get_asr_dataset_from_json( @@ -324,7 +327,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): combine=combine, upsample_primary=self.args.upsample_primary, num_buckets=self.args.num_batch_buckets, - shuffle=(split != getattr(self.args, "gen_subset", "")), + shuffle=(split != getattr(self.args, "gen_subset", None)), lf_mmi=(self.args.criterion == "lattice_free_mmi"), seed=self.args.seed, specaugment_config=self.specaugment_config, chunk_width=None if self.training_stage and split in self.args.valid_subset.split(",") else self.chunk_width, From 5351152deb6adcdd474a9116a4f5ec56ea2ff53a Mon Sep 17 00:00:00 2001 From: freewym Date: Fri, 21 Aug 2020 00:20:39 -0400 Subject: 
[PATCH 096/119] code adaptation/changes according to the commits on Aug 20-24, 2020 --- espresso/data/asr_chain_dataset.py | 4 +- espresso/data/asr_dataset.py | 115 ++++++++++------- espresso/data/asr_dictionary.py | 7 +- espresso/data/asr_xent_dataset.py | 4 +- espresso/speech_recognize.py | 164 ++++++++++++------------ espresso/speech_train.py | 2 +- espresso/tasks/speech_recognition.py | 9 +- espresso/tools/simple_greedy_decoder.py | 3 +- 8 files changed, 167 insertions(+), 141 deletions(-) diff --git a/espresso/data/asr_chain_dataset.py b/espresso/data/asr_chain_dataset.py index 163c4fb82..cce41f00d 100644 --- a/espresso/data/asr_chain_dataset.py +++ b/espresso/data/asr_chain_dataset.py @@ -326,9 +326,9 @@ def ordered_indices(self): """Return an ordered list of indices. Batches will be constructed based on this order.""" if self.shuffle: - indices = np.random.permutation(len(self)) + indices = np.random.permutation(len(self)).astype(np.int64) else: - indices = np.arange(len(self)) + indices = np.arange(len(self), dtype=np.int64) if self.buckets is None: # sort by target length, then source length if self.tgt_sizes is not None: diff --git a/espresso/data/asr_dataset.py b/espresso/data/asr_dataset.py index 0e5ab3c78..622ee2be0 100644 --- a/espresso/data/asr_dataset.py +++ b/espresso/data/asr_dataset.py @@ -29,79 +29,89 @@ def collate( return {} def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None): - if key == 'source': + if key == "source": return speech_utils.collate_frames( [s[key] for s in samples], 0.0, left_pad, pad_to_length=pad_to_length, ) - elif key == 'target' or key == 'prev_output_tokens': + elif key == "target" or key == "prev_output_tokens": return data_utils.collate_tokens( [s[key] for s in samples], pad_idx, eos_idx, left_pad, move_eos_to_beginning, pad_to_length=pad_to_length, ) else: - raise ValueError('Invalid key.') + raise ValueError("Invalid key.") - id = torch.LongTensor([s['id'] for s in samples]) + id = torch.LongTensor([s["id"] for s in samples]) src_frames = merge( - 'source', left_pad=left_pad_source, - pad_to_length=pad_to_length['source'] if pad_to_length is not None else None, + "source", left_pad=left_pad_source, + pad_to_length=pad_to_length["source"] if pad_to_length is not None else None, ) # sort by descending source length if pad_to_length is not None or src_bucketed: src_lengths = torch.IntTensor([ - s['source'].ne(0.0).any(dim=1).int().sum() for s in samples + s["source"].ne(0.0).any(dim=1).int().sum() for s in samples ]) else: - src_lengths = torch.IntTensor([s['source'].size(0) for s in samples]) + src_lengths = torch.IntTensor([s["source"].size(0) for s in samples]) src_lengths, sort_order = src_lengths.sort(descending=True) id = id.index_select(0, sort_order) - utt_id = [samples[i]['utt_id'] for i in sort_order.numpy()] + utt_id = [samples[i]["utt_id"] for i in sort_order.numpy()] src_frames = src_frames.index_select(0, sort_order) prev_output_tokens = None target = None - if samples[0].get('target', None) is not None: + if samples[0].get("target", None) is not None: target = merge( - 'target', left_pad=left_pad_target, - pad_to_length=pad_to_length['target'] if pad_to_length is not None else None, + "target", left_pad=left_pad_target, + pad_to_length=pad_to_length["target"] if pad_to_length is not None else None, ) target = target.index_select(0, sort_order) - ntokens = sum(s['target'].ne(pad_idx).int().sum().item() for s in samples) + ntokens = sum(s["target"].ne(pad_idx).int().sum().item() for s in samples) - if 
samples[0].get('prev_output_tokens', None) is not None: - prev_output_tokens = merge('prev_output_tokens', left_pad=left_pad_target) + if samples[0].get("prev_output_tokens", None) is not None: + prev_output_tokens = merge("prev_output_tokens", left_pad=left_pad_target) elif input_feeding: # we create a shifted version of targets for feeding the # previous output token(s) into the next decoder step prev_output_tokens = merge( - 'target', + "target", left_pad=left_pad_target, move_eos_to_beginning=True, - pad_to_length=pad_to_length['target'] if pad_to_length is not None else None, + pad_to_length=pad_to_length["target"] if pad_to_length is not None else None, ) else: ntokens = src_lengths.sum().item() target_raw_text = None - if samples[0].get('target_raw_text', None) is not None: - target_raw_text = [samples[i]['target_raw_text'] for i in sort_order.numpy()] + if samples[0].get("target_raw_text", None) is not None: + target_raw_text = [samples[i]["target_raw_text"] for i in sort_order.numpy()] batch = { - 'id': id, - 'utt_id': utt_id, - 'nsentences': len(samples), - 'ntokens': ntokens, - 'net_input': { - 'src_tokens': src_frames, - 'src_lengths': src_lengths, + "id": id, + "utt_id": utt_id, + "nsentences": len(samples), + "ntokens": ntokens, + "net_input": { + "src_tokens": src_frames, + "src_lengths": src_lengths, }, - 'target': target, - 'target_raw_text': target_raw_text, + "target": target, + "target_raw_text": target_raw_text, } if prev_output_tokens is not None: - batch['net_input']['prev_output_tokens'] = prev_output_tokens.index_select(0, sort_order) + batch["net_input"]["prev_output_tokens"] = prev_output_tokens.index_select(0, sort_order) + + if samples[0].get("constraints", None) is not None: + # Collate the packed constraints across the samples, padding to + # the length of the longest sample. + lens = [sample.get("constraints").size(0) for sample in samples] + constraints = torch.zeros((len(samples), max(lens))).long() + for i, sample in enumerate(samples): + constraints[i, 0:lens[i]] = samples[i].get("constraints") + batch["constraints"] = constraints + return batch @@ -123,6 +133,8 @@ class AsrDataset(FairseqDataset): (default: True). input_feeding (bool, optional): create a shifted version of the targets to be passed into the model for teacher forcing (default: True). + constraints (Tensor, optional): 2d tensor with a concatenated, zero- + delimited list of constraints for each sentence. num_buckets (int, optional): if set to a value greater than 0, then batches will be bucketed into the given number of batch shapes. 
src_lang_id (int, optional): source language ID, if set, the collated batch @@ -138,6 +150,7 @@ def __init__( tgt=None, tgt_sizes=None, dictionary=None, left_pad_source=False, left_pad_target=False, shuffle=True, input_feeding=True, + constraints=None, num_buckets=0, src_lang_id=None, tgt_lang_id=None, @@ -146,11 +159,13 @@ def __init__( self.tgt = tgt self.src_sizes = np.array(src_sizes) self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None + assert dictionary is not None self.dictionary = dictionary self.left_pad_source = left_pad_source self.left_pad_target = left_pad_target self.shuffle = shuffle self.input_feeding = input_feeding + self.constraints = constraints self.src_lang_id = src_lang_id self.tgt_lang_id = tgt_lang_id if self.tgt is not None: @@ -166,7 +181,7 @@ def __init__( left_pad=False, ) self.src_sizes = self.src.sizes - logger.info('bucketing source lengths: {}'.format(list(self.src.buckets))) + logger.info("bucketing source lengths: {}".format(list(self.src.buckets))) if self.tgt is not None: self.tgt = TextBucketPadLengthDataset( self.tgt, @@ -176,7 +191,7 @@ def __init__( left_pad=False, ) self.tgt_sizes = self.tgt.sizes - logger.info('bucketing target lengths: {}'.format(list(self.tgt.buckets))) + logger.info("bucketing target lengths: {}".format(list(self.tgt.buckets))) # determine bucket sizes using self.num_tokens, which will return # the padded lengths (thanks to FeatBucketPadLengthDataset) @@ -205,8 +220,8 @@ def _match_src_tgt(self): tgt_indices = list(map(self.tgt.utt_ids.index, self.src.utt_ids)) except ValueError: raise ValueError( - 'Unable to find some utt_id(s) in tgt. which is unlikely to happen. ' - 'Something must be wrong.' + "Unable to find some utt_id(s) in tgt. which is unlikely to happen. " + "Something must be wrong." ) self.tgt.filter_and_reorder(tgt_indices) self.tgt_sizes = np.array(self.tgt.sizes) @@ -220,12 +235,14 @@ def __getitem__(self, index): raw_text_item = self.tgt[index][1] if self.tgt is not None else None src_item = self.src[index] example = { - 'id': index, - 'utt_id': self.src.utt_ids[index], - 'source': src_item, - 'target': tgt_item, - 'target_raw_text': raw_text_item, + "id": index, + "utt_id": self.src.utt_ids[index], + "source": src_item, + "target": tgt_item, + "target_raw_text": raw_text_item, } + if self.constraints is not None: + example["constraints"] = self.constraints[index] return example def __len__(self): @@ -281,14 +298,14 @@ def collater(self, samples, pad_to_length=None): src_bucketed=(self.buckets is not None), ) if self.src_lang_id is not None or self.tgt_lang_id is not None: - src_tokens = res['net_input']['src_tokens'] + src_tokens = res["net_input"]["src_tokens"] bsz = src_tokens.size(0) if self.src_lang_id is not None: - res['net_input']['src_lang_id'] = torch.LongTensor( + res["net_input"]["src_lang_id"] = torch.LongTensor( [[self.src_lang_id]] ).expand(bsz, 1).to(src_tokens) if self.tgt_lang_id is not None: - res['tgt_lang_id'] = torch.LongTensor( + res["tgt_lang_id"] = torch.LongTensor( [[self.tgt_lang_id]] ).expand(bsz, 1).to(src_tokens) return res @@ -307,25 +324,25 @@ def ordered_indices(self): """Return an ordered list of indices. 
Batches will be constructed based on this order.""" if self.shuffle: - indices = np.random.permutation(len(self)) + indices = np.random.permutation(len(self)).astype(np.int64) else: - indices = np.arange(len(self)) + indices = np.arange(len(self), dtype=np.int64) if self.buckets is None: # sort by target length, then source length if self.tgt_sizes is not None: indices = indices[ - np.argsort(self.tgt_sizes[indices], kind='mergesort') + np.argsort(self.tgt_sizes[indices], kind="mergesort") ] - return indices[np.argsort(self.src_sizes[indices], kind='mergesort')] + return indices[np.argsort(self.src_sizes[indices], kind="mergesort")] else: # sort by bucketed_num_tokens, which is padded_src_len return indices[ - np.argsort(self.bucketed_num_tokens[indices], kind='mergesort') + np.argsort(self.bucketed_num_tokens[indices], kind="mergesort") ] @property def supports_prefetch(self): - return getattr(self.src, 'supports_prefetch', False) + return getattr(self.src, "supports_prefetch", False) def prefetch(self, indices): """Only prefetch src.""" @@ -365,7 +382,7 @@ def filter_indices_by_size(self, indices, max_sizes): def set_epoch(self, epoch): super().set_epoch(epoch) - if hasattr(self.src, 'set_epoch'): + if hasattr(self.src, "set_epoch"): self.src.set_epoch(epoch) - if self.tgt is not None and hasattr(self.tgt, 'set_epoch'): + if self.tgt is not None and hasattr(self.tgt, "set_epoch"): self.tgt.set_epoch(epoch) diff --git a/espresso/data/asr_dictionary.py b/espresso/data/asr_dictionary.py index 8dc4610b1..082ec2933 100644 --- a/espresso/data/asr_dictionary.py +++ b/espresso/data/asr_dictionary.py @@ -17,18 +17,19 @@ class AsrDictionary(Dictionary): def __init__( self, + bos="", pad="", eos="", unk="", - bos="", space="", extra_special_symbols=None, ): - self.unk_word, self.pad_word, self.eos_word, self.bos_word, self.space_word = \ - unk, pad, eos, bos, space + self.unk_word, self.bos_word, self.pad_word, self.eos_word, self.space_word = \ + unk, bos, pad, eos, space self.symbols = [] self.count = [] self.indices = {} + # no bos added to the dictionary self.pad_index = self.add_symbol(pad, n=0) self.eos_index = self.add_symbol(eos, n=0) self.unk_index = self.add_symbol(unk, n=0) diff --git a/espresso/data/asr_xent_dataset.py b/espresso/data/asr_xent_dataset.py index e012620f9..cff0cf099 100644 --- a/espresso/data/asr_xent_dataset.py +++ b/espresso/data/asr_xent_dataset.py @@ -558,9 +558,9 @@ def ordered_indices(self): """Return an ordered list of indices. Batches will be constructed based on this order.""" if self.shuffle: - indices = np.random.permutation(len(self)) + indices = np.random.permutation(len(self)).astype(np.int64) else: - indices = np.arange(len(self)) + indices = np.arange(len(self), dtype=np.int64) if self.buckets is None: # sort by target length, then source length if self.tgt_sizes is not None: diff --git a/espresso/speech_recognize.py b/espresso/speech_recognize.py index dc91673d4..77f8bc0b3 100755 --- a/espresso/speech_recognize.py +++ b/espresso/speech_recognize.py @@ -30,20 +30,20 @@ def main(args): - assert args.path is not None, '--path required for recognition!' + assert args.path is not None, "--path required for recognition!" 
assert not args.sampling or args.nbest == args.beam, \ - '--sampling requires --nbest to be equal to --beam' + "--sampling requires --nbest to be equal to --beam" if args.results_path is not None: os.makedirs(args.results_path, exist_ok=True) - output_path = os.path.join(args.results_path, 'decode.log') - with open(output_path, 'w', buffering=1, encoding='utf-8') as h: + output_path = os.path.join(args.results_path, "decode.log") + with open(output_path, "w", buffering=1, encoding="utf-8") as h: return _main(args, h) return _main(args, sys.stdout) def get_symbols_to_strip_from_output(generator): - if hasattr(generator, 'symbols_to_strip_from_output'): + if hasattr(generator, "symbols_to_strip_from_output"): return generator.symbols_to_strip_from_output else: return {generator.eos, generator.pad} @@ -51,12 +51,12 @@ def get_symbols_to_strip_from_output(generator): def _main(args, output_file): logging.basicConfig( - format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, stream=output_file, ) - logger = logging.getLogger('espresso.speech_recognize') + logger = logging.getLogger("espresso.speech_recognize") if output_file is not sys.stdout: # also print to stdout logger.addHandler(logging.StreamHandler(sys.stdout)) @@ -83,7 +83,7 @@ def _main(args, output_file): dictionary = task.target_dictionary # Load ensemble - logger.info('loading model(s) from {}'.format(args.path)) + logger.info("loading model(s) from {}".format(args.path)) models, _model_args = checkpoint_utils.load_model_ensemble( utils.split_paths(args.path), arg_overrides=eval(args.model_overrides), @@ -91,7 +91,7 @@ def _main(args, output_file): suffix=getattr(args, "checkpoint_suffix", ""), ) for i, m in enumerate(models): - if hasattr(m, 'is_wordlm') and m.is_wordlm: + if hasattr(m, "is_wordlm") and m.is_wordlm: # assume subword LM comes before word LM if isinstance(models[i - 1], FairseqLanguageModel): models[i-1] = MultiLevelLanguageModel( @@ -101,19 +101,19 @@ def _main(args, output_file): open_vocab=not args.disable_open_vocab, ) del models[i] - logger.info('LM fusion with Multi-level LM') + logger.info("LM fusion with Multi-level LM") else: models[i] = TensorizedLookaheadLanguageModel( m, dictionary, oov_penalty=args.oov_penalty, open_vocab=not args.disable_open_vocab, ) - logger.info('LM fusion with Look-ahead Word LM') + logger.info("LM fusion with Look-ahead Word LM") # assume subword LM comes after E2E models elif i == len(models) - 1 and isinstance(m, FairseqLanguageModel): - logger.info('LM fusion with Subword LM') + logger.info("LM fusion with Subword LM") if args.lm_weight != 0.0: - logger.info('using LM fusion with lm-weight={:.2f}'.format(args.lm_weight)) + logger.info("using LM fusion with lm-weight={:.2f}".format(args.lm_weight)) # Optimize ensemble for generation for model in models: @@ -130,7 +130,7 @@ def _main(args, output_file): max_sentences=args.max_sentences, max_positions=utils.resolve_max_positions( task.max_positions(), - *[model.max_positions() if hasattr(model, 'encoder') + *[model.max_positions() if hasattr(model, "encoder") else (None, model.max_positions()) for model in models] ), ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, @@ -143,13 +143,13 @@ def _main(args, output_file): itr, log_format=args.log_format, log_interval=args.log_interval, - default_log_format=('tqdm' if not args.no_progress_bar else 'none'), + 
default_log_format=("tqdm" if not args.no_progress_bar else "none"), ) # Initialize generator if args.match_source_len: logger.warning( - 'The option match_source_len is not applicable to speech recognition. Ignoring it.' + "The option match_source_len is not applicable to speech recognition. Ignoring it." ) gen_timer = StopwatchMeter() generator = task.build_generator(models, args) @@ -172,55 +172,59 @@ def decode_fn(x): wps_meter = TimeMeter() for sample in progress: sample = utils.move_to_cuda(sample) if use_cuda else sample - if 'net_input' not in sample: + if "net_input" not in sample: continue prefix_tokens = None if args.prefix_size > 0: - prefix_tokens = sample['target'][:, :args.prefix_size] + prefix_tokens = sample["target"][:, :args.prefix_size] + + constraints = None + if "constraints" in sample: + constraints = sample["constraints"] gen_timer.start() - hypos = task.inference_step(generator, models, sample, prefix_tokens) - num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos) + hypos = task.inference_step(generator, models, sample, prefix_tokens=prefix_tokens, constraints=constraints) + num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos) gen_timer.stop(num_generated_tokens) # obtain nonpad mask of encoder output to plot attentions if args.print_alignment: - net_input = sample['net_input'] - src_tokens = net_input['src_tokens'] - output_lengths = models[0].encoder.output_lengths(net_input['src_lengths']) + net_input = sample["net_input"] + src_tokens = net_input["src_tokens"] + output_lengths = models[0].encoder.output_lengths(net_input["src_lengths"]) nonpad_idxs = sequence_mask(output_lengths, models[0].encoder.output_lengths(src_tokens.size(1))) - for i in range(len(sample['id'])): - has_target = sample['target'] is not None - utt_id = sample['utt_id'][i] + for i in range(len(sample["id"])): + has_target = sample["target"] is not None + utt_id = sample["utt_id"][i] # Retrieve the original sentences if has_target: - target_str = sample['target_raw_text'][i] + target_str = sample["target_raw_text"][i] if not args.quiet: detok_target_str = decode_fn(target_str) - print('T-{}\t{}'.format(utt_id, detok_target_str), file=output_file) + print("T-{}\t{}".format(utt_id, detok_target_str), file=output_file) # Process top predictions for j, hypo in enumerate(hypos[i][:args.nbest]): hypo_str = dictionary.string( - hypo['tokens'].int().cpu(), + hypo["tokens"].int().cpu(), bpe_symbol=None, extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator), ) # not removing bpe at this point detok_hypo_str = decode_fn(hypo_str) if not args.quiet: - score = hypo['score'] / math.log(2) # convert to base 2 - print('H-{}\t{}\t{}'.format(utt_id, detok_hypo_str, score), file=output_file) + score = hypo["score"] / math.log(2) # convert to base 2 + print("H-{}\t{}\t{}".format(utt_id, detok_hypo_str, score), file=output_file) # Score and obtain attention only the top hypothesis if j == 0: # src_len x tgt_len - attention = hypo['attention'][nonpad_idxs[i]].float().cpu() \ - if args.print_alignment and hypo['attention'] is not None else None + attention = hypo["attention"][nonpad_idxs[i]].float().cpu() \ + if args.print_alignment and hypo["attention"] is not None else None if args.print_alignment and attention is not None: - save_dir = os.path.join(args.results_path, 'attn_plots') + save_dir = os.path.join(args.results_path, "attn_plots") os.makedirs(save_dir, exist_ok=True) plot_attention(attention, detok_hypo_str, utt_id, save_dir) scorer.add_prediction(utt_id, hypo_str) @@ 
-228,50 +232,50 @@ def decode_fn(x): scorer.add_evaluation(utt_id, target_str, hypo_str) wps_meter.update(num_generated_tokens) - progress.log({'wps': round(wps_meter.avg)}) - num_sentences += sample['nsentences'] if 'nsentences' in sample else sample['id'].numel() + progress.log({"wps": round(wps_meter.avg)}) + num_sentences += sample["nsentences"] if "nsentences" in sample else sample["id"].numel() - logger.info('NOTE: hypothesis and token scores are output in base 2') - logger.info('Recognized {} utterances ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format( + logger.info("NOTE: hypothesis and token scores are output in base 2") + logger.info("Recognized {} utterances ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)".format( num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg)) if args.print_alignment: - logger.info('Saved attention plots in ' + save_dir) + logger.info("Saved attention plots in " + save_dir) if has_target: scorer.add_ordered_utt_list(task.datasets[args.gen_subset].tgt.utt_ids) - fn = 'decoded_char_results.txt' - with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f: + fn = "decoded_char_results.txt" + with open(os.path.join(args.results_path, fn), "w", encoding="utf-8") as f: f.write(scorer.print_char_results()) - logger.info('Decoded char results saved as ' + f.name) + logger.info("Decoded char results saved as " + f.name) - fn = 'decoded_results.txt' - with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f: + fn = "decoded_results.txt" + with open(os.path.join(args.results_path, fn), "w", encoding="utf-8") as f: f.write(scorer.print_results()) - logger.info('Decoded results saved as ' + f.name) + logger.info("Decoded results saved as " + f.name) if has_target: - header = 'Recognize {} with beam={}: '.format(args.gen_subset, args.beam) - fn = 'wer' - with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f: - res = 'WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%'.format( + header = "Recognize {} with beam={}: ".format(args.gen_subset, args.beam) + fn = "wer" + with open(os.path.join(args.results_path, fn), "w", encoding="utf-8") as f: + res = "WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%".format( *(scorer.wer())) logger.info(header + res) - f.write(res + '\n') - logger.info('WER saved in ' + f.name) + f.write(res + "\n") + logger.info("WER saved in " + f.name) - fn = 'cer' - with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f: - res = 'CER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%'.format( + fn = "cer" + with open(os.path.join(args.results_path, fn), "w", encoding="utf-8") as f: + res = "CER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%".format( *(scorer.cer())) - logger.info(' ' * len(header) + res) - f.write(res + '\n') - logger.info('CER saved in ' + f.name) + logger.info(" " * len(header) + res) + f.write(res + "\n") + logger.info("CER saved in " + f.name) - fn = 'aligned_results.txt' - with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f: + fn = "aligned_results.txt" + with open(os.path.join(args.results_path, fn), "w", encoding="utf-8") as f: f.write(scorer.print_aligned_results()) - logger.info('Aligned results saved as ' + f.name) + logger.info("Aligned results saved as " + f.name) return scorer @@ -279,31 +283,31 @@ def print_options_meaning_changes(args, logger): """Options that have different meanings than those in the translation task are explained here. 
""" - logger.info('--max-tokens is the maximum number of input frames in a batch') + logger.info("--max-tokens is the maximum number of input frames in a batch") if args.print_alignment: - logger.info('--print-alignment has been set to plot attentions') + logger.info("--print-alignment has been set to plot attentions") def cli_main(): - parser = options.get_generation_parser(default_task='speech_recognition_espresso') - parser.add_argument('--eos-factor', default=None, type=float, metavar='F', - help='only consider emitting EOS if its score is no less ' - 'than the specified factor of the best candidate score') - parser.add_argument('--lm-weight', default=0.0, type=float, metavar='W', - help='LM weight in log-prob space, assuming the pretrained ' - 'external LM is specified as the second one in --path') - parser.add_argument('--subwordlm-weight', default=0.8, type=float, metavar='W', - help='subword LM weight relative to word LM. Only relevant ' - 'to MultiLevelLanguageModel as an external LM') - parser.add_argument('--oov-penalty', default=1e-4, type=float, - help='oov penalty with the pretrained external LM') - parser.add_argument('--disable-open-vocab', action='store_true', - help='whether open vocabulary mode is enabled with the ' - 'pretrained external LM') + parser = options.get_generation_parser(default_task="speech_recognition_espresso") + parser.add_argument("--eos-factor", default=None, type=float, metavar="F", + help="only consider emitting EOS if its score is no less " + "than the specified factor of the best candidate score") + parser.add_argument("--lm-weight", default=0.0, type=float, metavar="W", + help="LM weight in log-prob space, assuming the pretrained " + "external LM is specified as the second one in --path") + parser.add_argument("--subwordlm-weight", default=0.8, type=float, metavar="W", + help="subword LM weight relative to word LM. 
Only relevant " + "to MultiLevelLanguageModel as an external LM") + parser.add_argument("--oov-penalty", default=1e-4, type=float, + help="oov penalty with the pretrained external LM") + parser.add_argument("--disable-open-vocab", action="store_true", + help="whether open vocabulary mode is enabled with the " + "pretrained external LM") args = options.parse_args_and_arch(parser) - assert args.results_path is not None, 'please specify --results-path' + assert args.results_path is not None, "please specify --results-path" main(args) -if __name__ == '__main__': +if __name__ == "__main__": cli_main() diff --git a/espresso/speech_train.py b/espresso/speech_train.py index 564a5bdd4..d83103b14 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -244,7 +244,7 @@ def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoc do_validate = ( (not end_of_epoch and do_save) # validate during mid-epoch saves or (end_of_epoch and epoch_itr.epoch % args.validate_interval == 0) - or (args.validate_interval_updates > 0 and num_updates % args.validate_interval_updates == 0) + or (args.validate_interval_updates > 0 and num_updates > 0 and num_updates % args.validate_interval_updates == 0) ) and not args.disable_validation # Validate diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index 9a72385b8..c39c5abb0 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -286,6 +286,7 @@ def build_generator( diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5) match_source_len = getattr(args, "match_source_len", False) diversity_rate = getattr(args, "diversity_rate", -1) + constrained = getattr(args, "constraints", False) if ( sum( int(cond) @@ -325,6 +326,8 @@ def build_generator( search_strategy = search.DiverseSiblingsSearch( self.target_dictionary, diversity_rate ) + elif constrained: + search_strategy = search.LexicallyConstrainedBeamSearch(self.target_dictionary, args.constraints) else: search_strategy = search.BeamSearch(self.target_dictionary) @@ -351,8 +354,10 @@ def build_generator( **extra_gen_cls_kwargs, ) - def build_dataset_for_inference(self, src_tokens, src_lengths): - return AsrDataset(src_tokens, src_lengths) + def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): + return AsrDataset( + src_tokens, src_lengths, dictionary=self.target_dictionary, constraints=constraints, + ) def build_model(self, args): model = super().build_model(args) diff --git a/espresso/tools/simple_greedy_decoder.py b/espresso/tools/simple_greedy_decoder.py index 007e9555f..a34bee5f6 100644 --- a/espresso/tools/simple_greedy_decoder.py +++ b/espresso/tools/simple_greedy_decoder.py @@ -83,8 +83,7 @@ def _decode(self, sample: Dict[str, Dict[str, Tensor]], bos_token: Optional[int] ) net_input = sample["net_input"] src_tokens = net_input["src_tokens"] - input_size = src_tokens.size() - bsz, src_len = input_size[0], input_size[1] + bsz, src_len = src_tokens.size()[:2] # compute the encoder output encoder_outs = self.model.forward_encoder(net_input) From 9d710781a74cde68ba6716ea3b8ff5e565e53293 Mon Sep 17 00:00:00 2001 From: freewym Date: Mon, 31 Aug 2020 16:02:27 -0400 Subject: [PATCH 097/119] code adaptation/changes according to the commits on Aug 31, 2020 --- espresso/dump_posteriors.py | 5 ++-- espresso/models/speech_lstm.py | 4 +-- .../speech_transformer_encoder_model.py | 2 ++ espresso/speech_recognize.py | 2 +- espresso/speech_train.py | 28 +++++++++++++------ 5 files 
changed, 27 insertions(+), 14 deletions(-) diff --git a/espresso/dump_posteriors.py b/espresso/dump_posteriors.py index 8e90402c3..b21574c91 100755 --- a/espresso/dump_posteriors.py +++ b/espresso/dump_posteriors.py @@ -9,6 +9,7 @@ """ import logging +import os import sys import numpy as np @@ -34,7 +35,7 @@ def _main(args, output_file): logging.basicConfig( format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S", - level=logging.INFO, + level=os.environ.get("LOGLEVEL", "INFO").upper(), stream=output_file, ) logger = logging.getLogger("espresso.dump_posteriors") @@ -195,7 +196,7 @@ def print_options_meaning_changes(args, logger): """Options that have different meanings than those in the translation task are explained here. """ - logger.info("| --max-tokens is the maximum number of input frames in a batch") + logger.info("--max-tokens is the maximum number of input frames in a batch") def cli_main(): diff --git a/espresso/models/speech_lstm.py b/espresso/models/speech_lstm.py index 670d2663e..5806f338f 100644 --- a/espresso/models/speech_lstm.py +++ b/espresso/models/speech_lstm.py @@ -412,8 +412,8 @@ def forward( packed_x = nn.utils.rnn.pack_padded_sequence( x, ( - src_lengths.data if not self.src_bucketed else - src_lengths.new_full(src_lengths.size(), x.size(0)) + src_lengths.cpu() if not self.src_bucketed else + src_lengths.new_full(src_lengths.size(), x.size(0), device="cpu") ), enforce_sorted=enforce_sorted ) diff --git a/espresso/models/speech_transformer_encoder_model.py b/espresso/models/speech_transformer_encoder_model.py index 3ffc2abfa..641e0acc3 100644 --- a/espresso/models/speech_transformer_encoder_model.py +++ b/espresso/models/speech_transformer_encoder_model.py @@ -68,6 +68,8 @@ def add_args(parser): parser.add_argument("--encoder-transformer-context", type=str, metavar="EXPR", help="left/right context for time-restricted self-attention; " "can be None or a tuple of two non-negative integers/None") + parser.add_argument("--no-token-positional-embeddings", action="store_true", + help="if set, disables positional embeddings (outside self attention)") # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019) parser.add_argument("--encoder-layerdrop", type=float, metavar="D", default=0, help="LayerDrop probability for encoder") diff --git a/espresso/speech_recognize.py b/espresso/speech_recognize.py index 77f8bc0b3..f8a3623b7 100755 --- a/espresso/speech_recognize.py +++ b/espresso/speech_recognize.py @@ -53,7 +53,7 @@ def _main(args, output_file): logging.basicConfig( format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S", - level=logging.INFO, + level=os.environ.get("LOGLEVEL", "INFO").upper(), stream=output_file, ) logger = logging.getLogger("espresso.speech_recognize") diff --git a/espresso/speech_train.py b/espresso/speech_train.py index d83103b14..e7142951c 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -10,6 +10,7 @@ import logging import math +import os import sys import numpy as np @@ -31,7 +32,7 @@ logging.basicConfig( format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S", - level=logging.INFO, + level=os.environ.get("LOGLEVEL", "INFO").upper(), stream=sys.stdout, ) logger = logging.getLogger("espresso.speech_train") @@ -235,16 +236,26 @@ def train(args, trainer, task, epoch_itr): def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoch): num_updates = trainer.get_num_updates() 
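+    # reaching --max-update now also triggers a final save and validation below;
+    # max_update falls back to infinity when --max-update is unset or 0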
+ max_update = args.max_update or math.inf do_save = ( - args.save_interval_updates > 0 - and num_updates > 0 - and num_updates % args.save_interval_updates == 0 - and num_updates >= args.validate_after_updates - ) or (end_of_epoch and epoch_itr.epoch % args.save_interval == 0) + (end_of_epoch and epoch_itr.epoch % args.save_interval == 0) + or num_updates >= max_update + or ( + args.save_interval_updates > 0 + and num_updates > 0 + and num_updates % args.save_interval_updates == 0 + and num_updates >= args.validate_after_updates + ) + ) do_validate = ( (not end_of_epoch and do_save) # validate during mid-epoch saves or (end_of_epoch and epoch_itr.epoch % args.validate_interval == 0) - or (args.validate_interval_updates > 0 and num_updates > 0 and num_updates % args.validate_interval_updates == 0) + or num_updates >= max_update + or ( + args.validate_interval_updates > 0 + and num_updates > 0 + and num_updates % args.validate_interval_updates == 0 + ) ) and not args.disable_validation # Validate @@ -253,10 +264,9 @@ def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoc valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets) # Stopping conditions - max_update = args.max_update or math.inf should_stop = ( should_stop_early(args, valid_losses[0]) - or trainer.get_num_updates() >= max_update + or num_updates >= max_update or ( args.stop_time_hours > 0 and trainer.cumulative_training_time() / (60 * 60) > args.stop_time_hours From e34b27d4530d29fd2a0b4bbeffc0fae138c52056 Mon Sep 17 00:00:00 2001 From: freewym Date: Wed, 9 Sep 2020 12:30:45 -0400 Subject: [PATCH 098/119] code adaptation/changes according to the commits on Sep 9-11, 2020 --- espresso/data/asr_chain_dataset.py | 4 ++++ espresso/data/asr_dataset.py | 4 ++++ espresso/data/asr_xent_dataset.py | 4 ++++ espresso/dump_posteriors.py | 1 + espresso/models/lstm_lm.py | 6 +++--- espresso/models/speech_lstm.py | 14 +++++++------- espresso/models/speech_lstm_encoder_model.py | 6 +++--- espresso/models/speech_tdnn.py | 4 ++-- espresso/models/speech_transformer.py | 4 ++-- espresso/speech_recognize.py | 1 + espresso/speech_train.py | 9 ++++++++- 11 files changed, 39 insertions(+), 18 deletions(-) diff --git a/espresso/data/asr_chain_dataset.py b/espresso/data/asr_chain_dataset.py index cce41f00d..557d02c9d 100644 --- a/espresso/data/asr_chain_dataset.py +++ b/espresso/data/asr_chain_dataset.py @@ -382,6 +382,10 @@ def filter_indices_by_size(self, indices, max_sizes): (self.tgt_sizes[indices] <= max_tgt_size)] return indices, ignored.tolist() + @property + def can_reuse_epoch_itr_across_epochs(self): + return False # to avoid running out of CPU RAM + def set_epoch(self, epoch): super().set_epoch(epoch) self.epoch = epoch diff --git a/espresso/data/asr_dataset.py b/espresso/data/asr_dataset.py index 622ee2be0..418159751 100644 --- a/espresso/data/asr_dataset.py +++ b/espresso/data/asr_dataset.py @@ -380,6 +380,10 @@ def filter_indices_by_size(self, indices, max_sizes): (self.tgt_sizes[indices] <= max_tgt_size)] return indices, ignored.tolist() + @property + def can_reuse_epoch_itr_across_epochs(self): + return False # to avoid running out of CPU RAM + def set_epoch(self, epoch): super().set_epoch(epoch) if hasattr(self.src, "set_epoch"): diff --git a/espresso/data/asr_xent_dataset.py b/espresso/data/asr_xent_dataset.py index cff0cf099..2aaee458c 100644 --- a/espresso/data/asr_xent_dataset.py +++ b/espresso/data/asr_xent_dataset.py @@ -615,6 +615,10 @@ def filter_indices_by_size(self, indices, max_sizes): 
(self.tgt_sizes[indices] <= max_tgt_size)] return indices, ignored.tolist() + @property + def can_reuse_epoch_itr_across_epochs(self): + return False # to avoid running out of CPU RAM + def set_epoch(self, epoch): super().set_epoch(epoch) self.epoch = epoch diff --git a/espresso/dump_posteriors.py b/espresso/dump_posteriors.py index b21574c91..ed24e4693 100755 --- a/espresso/dump_posteriors.py +++ b/espresso/dump_posteriors.py @@ -114,6 +114,7 @@ def _main(args, output_file): num_shards=args.num_shards, shard_id=args.shard_id, num_workers=args.num_workers, + data_buffer_size=args.data_buffer_size, ).next_epoch_itr(shuffle=False) progress = progress_bar.progress_bar( itr, diff --git a/espresso/models/lstm_lm.py b/espresso/models/lstm_lm.py index 5639a2196..9be64291f 100644 --- a/espresso/models/lstm_lm.py +++ b/espresso/models/lstm_lm.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from fairseq import options, utils +from fairseq import utils from fairseq.models import ( FairseqLanguageModel, register_model, @@ -46,7 +46,7 @@ def add_args(parser): help="comma separated list of adaptive softmax cutoff points. " "Must be used with adaptive_loss criterion") parser.add_argument("--share-embed", - type=lambda x: options.eval_bool(x), + type=lambda x: utils.eval_bool(x), help="share input and output embeddings") parser.add_argument("--is-wordlm", action="store_true", help="whether it is word LM or subword LM. Only " @@ -119,7 +119,7 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): pretrained_embed=pretrained_decoder_embed, share_input_output_embed=args.share_embed, adaptive_softmax_cutoff=( - options.eval_str_list(args.adaptive_softmax_cutoff, type=int) + utils.eval_str_list(args.adaptive_softmax_cutoff, type=int) if args.criterion == "adaptive_loss" else None ), max_target_positions=max_target_positions, diff --git a/espresso/models/speech_lstm.py b/espresso/models/speech_lstm.py index 5806f338f..a644b2536 100644 --- a/espresso/models/speech_lstm.py +++ b/espresso/models/speech_lstm.py @@ -11,7 +11,7 @@ import torch.nn as nn import torch.nn.functional as F -from fairseq import options, utils, checkpoint_utils +from fairseq import utils, checkpoint_utils from fairseq.models import ( FairseqDecoder, FairseqEncoder, @@ -67,10 +67,10 @@ def add_args(parser): parser.add_argument("--encoder-rnn-layers", type=int, metavar="N", help="number of rnn encoder layers") parser.add_argument("--encoder-rnn-bidirectional", - type=lambda x: options.eval_bool(x), + type=lambda x: utils.eval_bool(x), help="make all rnn layers of encoder bidirectional") parser.add_argument("--encoder-rnn-residual", - type=lambda x: options.eval_bool(x), + type=lambda x: utils.eval_bool(x), help="create residual connections for rnn encoder " "layers (starting from the 2nd layer), i.e., the actual " "output of such layer is the sum of its input and output") @@ -87,7 +87,7 @@ def add_args(parser): parser.add_argument("--decoder-out-embed-dim", type=int, metavar="N", help="decoder output embedding dimension") parser.add_argument("--decoder-rnn-residual", - type=lambda x: options.eval_bool(x), + type=lambda x: utils.eval_bool(x), help="create residual connections for rnn decoder " "layers (starting from the 2nd layer), i.e., the actual " "output of such layer is the sum of its input and output") @@ -102,7 +102,7 @@ def add_args(parser): help="comma separated list of adaptive softmax cutoff points. 
" "Must be used with adaptive_loss criterion") parser.add_argument("--share-decoder-input-output-embed", - type=lambda x: options.eval_bool(x), + type=lambda x: utils.eval_bool(x), help="share decoder input and output embeddings") parser.add_argument("--pretrained-lm-checkpoint", type=str, metavar="STR", help="path to load checkpoint from pretrained language model(LM), " @@ -119,7 +119,7 @@ def add_args(parser): help="dropout probability for decoder output") # Scheduled sampling options - parser.add_argument("--scheduled-sampling-probs", type=lambda p: options.eval_str_list(p), + parser.add_argument("--scheduled-sampling-probs", type=lambda p: utils.eval_str_list(p), metavar="P_1,P_2,...,P_N", default=[1.0], help="scheduled sampling probabilities of sampling the truth " "labels for N epochs starting from --start-schedule-sampling-epoch; " @@ -220,7 +220,7 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): pretrained_embed=pretrained_decoder_embed, share_input_output_embed=args.share_decoder_input_output_embed, adaptive_softmax_cutoff=( - options.eval_str_list(args.adaptive_softmax_cutoff, type=int) + utils.eval_str_list(args.adaptive_softmax_cutoff, type=int) if args.criterion == "adaptive_loss" else None ), max_target_positions=max_target_positions, diff --git a/espresso/models/speech_lstm_encoder_model.py b/espresso/models/speech_lstm_encoder_model.py index 4940993f8..c7eed3833 100644 --- a/espresso/models/speech_lstm_encoder_model.py +++ b/espresso/models/speech_lstm_encoder_model.py @@ -10,7 +10,7 @@ from torch import Tensor import torch.nn.functional as F -from fairseq import options +from fairseq import utils from fairseq.models import ( FairseqEncoderModel, register_model, @@ -52,10 +52,10 @@ def add_args(parser): parser.add_argument("--encoder-rnn-layers", type=int, metavar="N", help="number of rnn encoder layers") parser.add_argument("--encoder-rnn-bidirectional", - type=lambda x: options.eval_bool(x), + type=lambda x: utils.eval_bool(x), help="make all rnn layers of encoder bidirectional") parser.add_argument("--encoder-rnn-residual", - type=lambda x: options.eval_bool(x), + type=lambda x: utils.eval_bool(x), help="create residual connections for rnn encoder " "layers (starting from the 2nd layer), i.e., the actual " "output of such layer is the sum of its input and output") diff --git a/espresso/models/speech_tdnn.py b/espresso/models/speech_tdnn.py index 4441270d4..8d22a5e2e 100644 --- a/espresso/models/speech_tdnn.py +++ b/espresso/models/speech_tdnn.py @@ -11,7 +11,7 @@ import torch.nn as nn import torch.nn.functional as F -from fairseq import options +from fairseq import utils from fairseq.models import ( FairseqEncoder, FairseqEncoderModel, @@ -51,7 +51,7 @@ def add_args(parser): help="list of all Tdnn layer\'s dilations") parser.add_argument("--num-layers", type=int, metavar="N", help="number of Tdnn layers") - parser.add_argument("--residual", type=lambda x: options.eval_bool(x), + parser.add_argument("--residual", type=lambda x: utils.eval_bool(x), help="create residual connections for rnn encoder " "layers (starting from the 2nd layer), i.e., the actual " "output of such layer is the sum of its input and output") diff --git a/espresso/models/speech_transformer.py b/espresso/models/speech_transformer.py index 7778b7187..2f587493d 100644 --- a/espresso/models/speech_transformer.py +++ b/espresso/models/speech_transformer.py @@ -10,7 +10,7 @@ from torch import Tensor import torch.nn as nn -from fairseq import options +from fairseq import utils 
from fairseq.models import ( register_model, register_model_architecture, @@ -89,7 +89,7 @@ def add_args(parser): "if different from decoder embed dim)") # Scheduled sampling options - parser.add_argument("--scheduled-sampling-probs", type=lambda p: options.eval_str_list(p), + parser.add_argument("--scheduled-sampling-probs", type=lambda p: utils.eval_str_list(p), metavar="P_1,P_2,...,P_N", default=[1.0], help="scheduled sampling probabilities of sampling the truth " "labels for N epochs starting from --start-schedule-sampling-epoch; " diff --git a/espresso/speech_recognize.py b/espresso/speech_recognize.py index f8a3623b7..4279c946e 100755 --- a/espresso/speech_recognize.py +++ b/espresso/speech_recognize.py @@ -138,6 +138,7 @@ def _main(args, output_file): num_shards=args.num_shards, shard_id=args.shard_id, num_workers=args.num_workers, + data_buffer_size=args.data_buffer_size, ).next_epoch_itr(shuffle=False) progress = progress_bar.progress_bar( itr, diff --git a/espresso/speech_train.py b/espresso/speech_train.py index e7142951c..d794f047b 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -106,7 +106,12 @@ def main(args): # Load the latest checkpoint if one is available and restore the # corresponding train iterator - extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer) + extra_state, epoch_itr = checkpoint_utils.load_checkpoint( + args, + trainer, + # don't cache epoch iterators for sharded datasets + disable_iterator_cache=task.has_sharded_data("train"), + ) # Train until the learning rate gets too small max_epoch = args.max_epoch or math.inf @@ -127,6 +132,8 @@ def main(args): epoch_itr.next_epoch_idx, # sharded data: get train iterator for next epoch load_dataset=task.has_sharded_data("train"), + # don't cache epoch iterators for sharded datasets + disable_iterator_cache=task.has_sharded_data("train"), ) train_meter.stop() logger.info("done training in {:.1f} seconds".format(train_meter.sum)) From 3fe02ae0c5fee4fc27644ce332a7f08f1c9ebca3 Mon Sep 17 00:00:00 2001 From: freewym Date: Fri, 18 Sep 2020 04:12:09 -0400 Subject: [PATCH 099/119] code adaptation/changes according to the commits on Sep 17-26, 2020 --- espresso/criterions/cross_entropy_v2.py | 32 +++++--- .../label_smoothed_cross_entropy_v2.py | 61 ++++++++++----- espresso/criterions/lf_mmi_loss.py | 57 +++++++++----- .../subsampled_cross_entropy_with_accuracy.py | 22 +++++- espresso/data/asr_chain_dataset.py | 39 +++++----- espresso/data/asr_dataset.py | 32 ++++---- espresso/data/asr_dictionary.py | 4 +- espresso/data/asr_xent_dataset.py | 33 ++++----- espresso/models/lstm_lm.py | 74 +++++++++++++++++++ .../lr_scheduler/reduce_lr_on_plateau_v2.py | 53 +++++++++++-- espresso/speech_train.py | 1 + espresso/tasks/language_modeling_for_asr.py | 17 +++-- espresso/tasks/speech_recognition.py | 3 + espresso/tasks/speech_recognition_hybrid.py | 7 +- espresso/tools/utils.py | 4 +- fairseq/sequence_generator.py | 3 +- 16 files changed, 312 insertions(+), 130 deletions(-) diff --git a/espresso/criterions/cross_entropy_v2.py b/espresso/criterions/cross_entropy_v2.py index 39f11df00..12b5df138 100644 --- a/espresso/criterions/cross_entropy_v2.py +++ b/espresso/criterions/cross_entropy_v2.py @@ -3,6 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from dataclasses import dataclass, field +from omegaconf import II import logging import numpy as np @@ -10,33 +12,39 @@ from fairseq import utils from fairseq.criterions import register_criterion -from fairseq.criterions.cross_entropy import CrossEntropyCriterion +from fairseq.criterions.cross_entropy import CrossEntropyCriterion, CrossEntropyCriterionConfig from fairseq.data import data_utils +from fairseq.dataclass.utils import gen_parser_from_dataclass logger = logging.getLogger(__name__) +@dataclass +class CrossEntropyV2CriterionConfig(CrossEntropyCriterionConfig): + print_training_sample_interval: int = field( + default=500, + metadata={ + "help": "print a training sample (reference + prediction) every this number of updates" + }, + ) + + @register_criterion("cross_entropy_v2") class CrossEntropyV2Criterion(CrossEntropyCriterion): - def __init__(self, task, sentence_avg, print_interval): + def __init__(self, task, sentence_avg, print_training_sample_interval): super().__init__(task, sentence_avg) self.dictionary = task.target_dictionary - self.print_interval = print_interval + self.print_interval = print_training_sample_interval self.epoch = 1 self.prev_num_updates = -1 @staticmethod def add_args(parser): - """Add criterion-specific arguments to the parser.""" - # fmt: off - parser.add_argument("--print-training-sample-interval", type=int, - metavar="N", dest="print_interval", default=500, - help="print a training sample (reference + " - "prediction) every this number of updates") - # fmt: on + """Add criterion-specific arguments to the parser. Optionally register config store""" + gen_parser_from_dataclass(parser, CrossEntropyV2CriterionConfig()) def forward(self, model, sample, reduce=True): """Compute the loss for the given sample; periodically print out @@ -49,7 +57,9 @@ def forward(self, model, sample, reduce=True): """ net_output = model(**sample["net_input"], epoch=self.epoch) loss, _, lprobs = self.compute_loss(model, net_output, sample, reduce=reduce) - sample_size = sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) logging_output = { "loss": loss.data, "ntokens": sample["ntokens"], diff --git a/espresso/criterions/label_smoothed_cross_entropy_v2.py b/espresso/criterions/label_smoothed_cross_entropy_v2.py index 8a522c0e1..eb6e3fa13 100644 --- a/espresso/criterions/label_smoothed_cross_entropy_v2.py +++ b/espresso/criterions/label_smoothed_cross_entropy_v2.py @@ -3,6 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from dataclasses import dataclass, field +from omegaconf import II import logging import numpy as np @@ -12,11 +14,44 @@ from fairseq.criterions import register_criterion from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion from fairseq.data import data_utils +from fairseq.dataclass.data_class import DDP_BACKEND_CHOICES +from fairseq.dataclass.utils import ChoiceEnum, FairseqDataclass, gen_parser_from_dataclass logger = logging.getLogger(__name__) +LABEL_SMOOTHING_CHOICES = ChoiceEnum(["uniform", "unigram", "temporal"]) + + +@dataclass +class LabelSmoothedCrossEntropyV2CriterionConfig(FairseqDataclass): + sentence_avg: bool = II("params.optimization.sentence_avg") + ddp_backend: DDP_BACKEND_CHOICES = II("params.distributed_training.ddp_backend") + label_smoothing: float = field( + default=0.0, + metadata={ + "help": "epsilon for label smoothing, 0 means no label smoothing" + }, + ) + print_training_sample_interval: int = field( + default=500, + metadata={ + "help": "print a training sample (reference + prediction) every this number of updates" + }, + ) + smoothing_type: LABEL_SMOOTHING_CHOICES = field( + default="uniform", + metadata={"help": "label smoothing type. Default: uniform"}, + ) + unigram_pseudo_count: float = field( + default=1.0, + metadata={ + "help": "pseudo count for unigram label smoothing. Only relevant if --smoothing-type=unigram" + }, + ) + + def temporal_label_smoothing_prob_mask( lprobs: torch.Tensor, # R[Batch, SeqLength, Vocab] target: torch.Tensor, # Z[Batch, SeqLength] @@ -80,14 +115,14 @@ def label_smoothed_nll_loss( class LabelSmoothedCrossEntropyV2Criterion(LabelSmoothedCrossEntropyCriterion): def __init__( - self, task, sentence_avg, label_smoothing, smoothing_type, print_interval, - unigram_pseudo_count, + self, task, sentence_avg, label_smoothing, smoothing_type, + print_training_sample_interval, unigram_pseudo_count, ): super().__init__(task, sentence_avg, label_smoothing) self.dictionary = task.target_dictionary self.smoothing_type = smoothing_type - self.print_interval = print_interval + self.print_interval = print_training_sample_interval self.epoch = 1 self.unigram_tensor = None if smoothing_type == "unigram": @@ -98,20 +133,8 @@ def __init__( @staticmethod def add_args(parser): - """Add criterion-specific arguments to the parser.""" - # fmt: off - LabelSmoothedCrossEntropyCriterion.add_args(parser) - parser.add_argument("--print-training-sample-interval", type=int, - metavar="N", dest="print_interval", default=500, - help="print a training sample (reference + " - "prediction) every this number of updates") - parser.add_argument("--smoothing-type", type=str, default="uniform", - choices=["uniform", "unigram", "temporal"], - help="label smoothing type. Default: uniform") - parser.add_argument("--unigram-pseudo-count", type=float, default=1.0, - metavar="C", help="pseudo count for unigram label " - "smoothing. Only relevant if --smoothing-type=unigram") - # fmt: on + """Add criterion-specific arguments to the parser. 
Optionally register config store""" + gen_parser_from_dataclass(parser, LabelSmoothedCrossEntropyV2CriterionConfig()) def forward(self, model, sample, reduce=True): """Compute the loss for the given sample; periodically print out @@ -126,7 +149,9 @@ def forward(self, model, sample, reduce=True): loss, nll_loss, lprobs = self.compute_loss( model, net_output, sample, reduce=reduce, smoothing_type=self.smoothing_type ) - sample_size = sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) logging_output = { "loss": loss.data, "nll_loss": nll_loss.data, diff --git a/espresso/criterions/lf_mmi_loss.py b/espresso/criterions/lf_mmi_loss.py index 6c04ee01e..495c65612 100644 --- a/espresso/criterions/lf_mmi_loss.py +++ b/espresso/criterions/lf_mmi_loss.py @@ -3,6 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass, field +from omegaconf import II import logging import math @@ -10,12 +12,35 @@ from fairseq import utils from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass.data_class import DDP_BACKEND_CHOICES +from fairseq.dataclass.utils import ChoiceEnum, FairseqDataclass, gen_parser_from_dataclass from fairseq.logging import metrics logger = logging.getLogger(__name__) +@dataclass +class LatticeFreeMMICriterionConfig(FairseqDataclass): + sentence_avg: bool = II("params.optimization.sentence_avg") + ddp_backend: DDP_BACKEND_CHOICES = II("params.distributed_training.ddp_backend") + denominator_fst_path: str = field( + default=None, metadata={"help": "path to the denominator fst file"} + ) + leaky_hmm_coefficient: float = field( + default=1.0e-05, + metadata={"help": "leaky-hmm coefficient for the denominator"}, + ) + xent_regularization_coefficient: float = field( + default=0.0, + metadata={"help": "cross-entropy regularization coefficient"}, + ) + output_l2_regularization_coefficient: float = field( + default=0.0, + metadata={"help": "L2 regularization coefficient for the network's output"}, + ) + + class ChainLossFunction(torch.autograd.Function): @staticmethod def forward(ctx, input, input_lengths, num_graphs, den_graphs, leaky_coefficient=1e-5): @@ -132,20 +157,8 @@ def __init__( @staticmethod def add_args(parser): - """Add criterion-specific arguments to the parser.""" - # fmt: off - FairseqCriterion.add_args(parser) - parser.add_argument("--denominator-fst-path", type=str, metavar="FILE", - help="path to the denominator fst file") - parser.add_argument("--leaky-hmm-coefficient", default=1.0e-05, type=float, metavar="F", - help="leaky-hmm coefficient for the denominator") - parser.add_argument("--xent-regularization-coefficient", default=0.0, - type=float, metavar="F", dest="xent_regularize", - help="cross-entropy regularization coefficient") - parser.add_argument("--output-l2-regularization-coefficient", default=0.0, - type=float, metavar="F", dest="output_l2_regularize", - help="L2 regularization coefficient for the network's output") - # fmt: on + """Add criterion-specific arguments to the parser. Optionally register config store""" + gen_parser_from_dataclass(parser, LatticeFreeMMICriterionConfig()) def forward(self, model, sample, reduce=True): """Compute the loss for the given sample. 
@@ -158,7 +171,9 @@ def forward(self, model, sample, reduce=True): net_output = model(**sample["net_input"]) loss, nll_loss = self.compute_loss(net_output, sample, reduce=reduce) - sample_size = sample["target"].batch_size if self.sentence_avg else sample["ntokens"] + sample_size = ( + sample["target"].batch_size if self.sentence_avg else sample["ntokens"] + ) logging_output = { "loss": loss.data, "nll_loss": nll_loss.data, @@ -211,9 +226,15 @@ def reduce_metrics(logging_outputs) -> None: ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) - metrics.log_scalar("loss", loss_sum / sample_size / math.log(2), sample_size, round=7) - metrics.log_scalar("nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=7) - metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg, round=4)) + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=7 + ) + metrics.log_scalar( + "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=7 + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg, round=4) + ) @staticmethod def logging_outputs_can_be_summed() -> bool: diff --git a/espresso/criterions/subsampled_cross_entropy_with_accuracy.py b/espresso/criterions/subsampled_cross_entropy_with_accuracy.py index 8ee620276..3ee5f1699 100644 --- a/espresso/criterions/subsampled_cross_entropy_with_accuracy.py +++ b/espresso/criterions/subsampled_cross_entropy_with_accuracy.py @@ -3,19 +3,26 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass import logging import torch import torch.nn.functional as F from fairseq.criterions import register_criterion -from fairseq.criterions.cross_entropy import CrossEntropyCriterion +from fairseq.criterions.cross_entropy import CrossEntropyCriterion, CrossEntropyCriterionConfig +from fairseq.dataclass.utils import gen_parser_from_dataclass from fairseq.logging import metrics logger = logging.getLogger(__name__) +@dataclass +class SubsampledCrossEntropyWithAccuracyCriterionConfig(CrossEntropyCriterionConfig): + pass + + @register_criterion("subsampled_cross_entropy_with_accuracy") class SubsampledCrossEntropyWithAccuracyCriterion(CrossEntropyCriterion): @@ -27,6 +34,11 @@ def __init__(self, task, sentence_avg): self.transpose_net_output = getattr(task, "transpose_net_output", True) self.state_prior_update_interval = getattr(task, "state_prior_update_interval", None) + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser. optionaly register config store""" + gen_parser_from_dataclass(parser, SubsampledCrossEntropyWithAccuracyCriterionConfig()) + def forward(self, model, sample, reduce=True): """Compute the loss for the given sample. 
@@ -37,7 +49,9 @@ def forward(self, model, sample, reduce=True): """ net_output = model(**sample["net_input"]) loss, num_corr, num_tot, state_post = self.compute_loss(model, net_output, sample, reduce=reduce) - sample_size = sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) logging_output = { "loss": loss.data, "ntokens": sample["ntokens"], @@ -99,7 +113,9 @@ def reduce_metrics(logging_outputs) -> None: CrossEntropyCriterion.reduce_metrics(logging_outputs) num_corr = sum(log.get("num_corr", 0) for log in logging_outputs) num_tot = sum(log.get("num_tot", 0) for log in logging_outputs) - metrics.log_scalar("accuracy", num_corr.float() / num_tot * 100 if num_tot > 0 else 0.0, num_tot, round=3) + metrics.log_scalar( + "accuracy", num_corr.float() / num_tot * 100 if num_tot > 0 else 0.0, num_tot, round=3 + ) @staticmethod def logging_outputs_can_be_summed() -> bool: diff --git a/espresso/data/asr_chain_dataset.py b/espresso/data/asr_chain_dataset.py index 557d02c9d..25cfc4198 100644 --- a/espresso/data/asr_chain_dataset.py +++ b/espresso/data/asr_chain_dataset.py @@ -12,7 +12,7 @@ import torch -from fairseq.data import FairseqDataset +from fairseq.data import data_utils, FairseqDataset import espresso.tools.utils as speech_utils @@ -20,7 +20,7 @@ logger = logging.getLogger(__name__) -def collate(samples, pad_to_length=None, src_bucketed=False): +def collate(samples, pad_to_length=None, pad_to_multiple=1, src_bucketed=False): try: from pychain import ChainGraphBatch except ImportError: @@ -34,6 +34,7 @@ def merge(key, pad_to_length=None): return speech_utils.collate_frames( [s[key] for s in samples], 0.0, pad_to_length=pad_to_length, + pad_to_multiple=pad_to_multiple, ) elif key == "target": max_num_transitions = max(s["target"].num_transitions for s in samples) @@ -167,11 +168,12 @@ class AsrChainDataset(FairseqDataset): (default: True). num_buckets (int, optional): if set to a value greater than 0, then batches will be bucketed into the given number of batch shapes. + pad_to_multiple (int, optional): pad src lengths to a multiple of this value """ def __init__( self, src, src_sizes, tgt=None, tgt_sizes=None, text=None, shuffle=True, - num_buckets=0, + num_buckets=0, pad_to_multiple=1, ): self.src = src self.tgt = tgt @@ -194,6 +196,7 @@ def __init__( "Removed {} examples due to empty numerator graphs or missing entries, " "{} remaining".format(num_removed, num_after_matching) ) + self.sizes = np.vstack((self.src_sizes, self.tgt_sizes)).T if self.tgt_sizes is not None else self.src_sizes if num_buckets > 0: from espresso.data import FeatBucketPadLengthDataset @@ -217,6 +220,7 @@ def __init__( ] else: self.buckets = None + self.pad_to_multiple = pad_to_multiple def _match_src_tgt(self): """Makes utterances in src and tgt the same order in terms of @@ -310,7 +314,10 @@ def collater(self, samples, pad_to_length=None): numerator graphs - `text` (List[str]): list of original text """ - return collate(samples, pad_to_length=pad_to_length, src_bucketed=(self.buckets is not None)) + return collate( + samples, pad_to_length=pad_to_length, pad_to_multiple=self.pad_to_multiple, + src_bucketed=(self.buckets is not None), + ) def num_tokens(self, index): """Return the number of frames in a sample. 
This value is used to @@ -363,24 +370,12 @@ def filter_indices_by_size(self, indices, max_sizes): np.array: filtered sample array list: list of removed indices """ - if max_sizes is None: - return indices, [] - if type(max_sizes) in (int, float): - max_src_size, max_tgt_size = max_sizes, max_sizes - else: - max_src_size, max_tgt_size = max_sizes - if self.tgt_sizes is None: - ignored = indices[self.src_sizes[indices] > max_src_size] - else: - ignored = indices[(self.src_sizes[indices] > max_src_size) | - (self.tgt_sizes[indices] > max_tgt_size)] - if len(ignored) > 0: - if self.tgt_sizes is None: - indices = indices[self.src_sizes[indices] <= max_src_size] - else: - indices = indices[(self.src_sizes[indices] <= max_src_size) & - (self.tgt_sizes[indices] <= max_tgt_size)] - return indices, ignored.tolist() + return data_utils.filter_paired_dataset_indices_by_size( + self.src_sizes, + self.tgt_sizes, + indices, + max_sizes, + ) @property def can_reuse_epoch_itr_across_epochs(self): diff --git a/espresso/data/asr_dataset.py b/espresso/data/asr_dataset.py index 418159751..5fd41eedc 100644 --- a/espresso/data/asr_dataset.py +++ b/espresso/data/asr_dataset.py @@ -23,6 +23,7 @@ def collate( left_pad_target=False, input_feeding=True, pad_to_length=None, + pad_to_multiple=1, src_bucketed=False, ): if len(samples) == 0: @@ -33,12 +34,14 @@ def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None): return speech_utils.collate_frames( [s[key] for s in samples], 0.0, left_pad, pad_to_length=pad_to_length, + pad_to_multiple=pad_to_multiple, ) elif key == "target" or key == "prev_output_tokens": return data_utils.collate_tokens( [s[key] for s in samples], pad_idx, eos_idx, left_pad, move_eos_to_beginning, pad_to_length=pad_to_length, + pad_to_multiple=pad_to_multiple, ) else: raise ValueError("Invalid key.") @@ -143,6 +146,7 @@ class AsrDataset(FairseqDataset): tgt_lang_id (int, optional): target language ID, if set, the collated batch will contain a field 'tgt_lang_id' which indicates the target language of the samples. 
+ pad_to_multiple (int, optional): pad src/tgt lengths to a multiple of this value """ def __init__( @@ -154,6 +158,7 @@ def __init__( num_buckets=0, src_lang_id=None, tgt_lang_id=None, + pad_to_multiple=1, ): self.src = src self.tgt = tgt @@ -170,6 +175,7 @@ def __init__( self.tgt_lang_id = tgt_lang_id if self.tgt is not None: self._match_src_tgt() + self.sizes = np.vstack((self.src_sizes, self.tgt_sizes)).T if self.tgt_sizes is not None else self.src_sizes if num_buckets > 0: from espresso.data import FeatBucketPadLengthDataset, TextBucketPadLengthDataset @@ -203,6 +209,7 @@ def __init__( ] else: self.buckets = None + self.pad_to_multiple = pad_to_multiple def _match_src_tgt(self): """Makes utterances in src and tgt the same order in terms of @@ -295,6 +302,7 @@ def collater(self, samples, pad_to_length=None): left_pad_target=self.left_pad_target, input_feeding=self.input_feeding, pad_to_length=pad_to_length, + pad_to_multiple=self.pad_to_multiple, src_bucketed=(self.buckets is not None), ) if self.src_lang_id is not None or self.tgt_lang_id is not None: @@ -361,24 +369,12 @@ def filter_indices_by_size(self, indices, max_sizes): np.array: filtered sample array list: list of removed indices """ - if max_sizes is None: - return indices, [] - if type(max_sizes) in (int, float): - max_src_size, max_tgt_size = max_sizes, max_sizes - else: - max_src_size, max_tgt_size = max_sizes - if self.tgt_sizes is None: - ignored = indices[self.src_sizes[indices] > max_src_size] - else: - ignored = indices[(self.src_sizes[indices] > max_src_size) | - (self.tgt_sizes[indices] > max_tgt_size)] - if len(ignored) > 0: - if self.tgt_sizes is None: - indices = indices[self.src_sizes[indices] <= max_src_size] - else: - indices = indices[(self.src_sizes[indices] <= max_src_size) & - (self.tgt_sizes[indices] <= max_tgt_size)] - return indices, ignored.tolist() + return data_utils.filter_paired_dataset_indices_by_size( + self.src_sizes, + self.tgt_sizes, + indices, + max_sizes, + ) @property def can_reuse_epoch_itr_across_epochs(self): diff --git a/espresso/data/asr_dictionary.py b/espresso/data/asr_dictionary.py index 082ec2933..2b19b45b7 100644 --- a/espresso/data/asr_dictionary.py +++ b/espresso/data/asr_dictionary.py @@ -24,8 +24,8 @@ def __init__( space="", extra_special_symbols=None, ): - self.unk_word, self.bos_word, self.pad_word, self.eos_word, self.space_word = \ - unk, bos, pad, eos, space + self.bos_word, self.unk_word, self.pad_word, self.eos_word, self.space_word = \ + bos, unk, pad, eos, space self.symbols = [] self.count = [] self.indices = {} diff --git a/espresso/data/asr_xent_dataset.py b/espresso/data/asr_xent_dataset.py index 2aaee458c..76749e9ba 100644 --- a/espresso/data/asr_xent_dataset.py +++ b/espresso/data/asr_xent_dataset.py @@ -35,6 +35,7 @@ def collate( seed, epoch, pad_to_length=None, + pad_to_multiple=1, src_bucketed=False, random_chunking=True, ): @@ -46,6 +47,7 @@ def merge(key, pad_to_length=None): return speech_utils.collate_frames( [s[key] for s in samples], 0.0, pad_to_length=pad_to_length, + pad_to_multiple=pad_to_multiple, ) elif key == "target": return data_utils.collate_tokens( @@ -53,6 +55,7 @@ def merge(key, pad_to_length=None): pad_idx=pad_idx, eos_idx=None, left_pad=False, move_eos_to_beginning=False, pad_to_length=pad_to_length, + pad_to_multiple=pad_to_multiple, ) else: raise ValueError("Invalid key.") @@ -343,6 +346,7 @@ class AsrXentDataset(FairseqDataset): (default: True). 
num_buckets (int, optional): if set to a value greater than 0, then batches will be bucketed into the given number of batch shapes. + pad_to_multiple (int, optional): pad src/tgt lengths to a multiple of this value seed (int, optional): random seed for generating a chunk from an utterance. chunk_width (int, optional): chunk width for chunk-wise training. chunk_left_context (int, optional): number of frames appended to the left of a chunk. @@ -355,7 +359,7 @@ class AsrXentDataset(FairseqDataset): def __init__( self, src, src_sizes, tgt: Optional[AliScpCachedDataset] = None, tgt_sizes=None, text=None, - shuffle=True, num_buckets=0, seed=1, chunk_width=None, + shuffle=True, num_buckets=0, pad_to_multiple=1, seed=1, chunk_width=None, chunk_left_context=None, chunk_right_context=None, label_delay=0, random_chunking=True, ): self.src = src @@ -381,6 +385,7 @@ def __init__( changed = self._match_src_text() if self.tgt is not None and changed: self._match_src_tgt() + self.sizes = np.vstack((self.src_sizes, self.tgt_sizes)).T if self.tgt_sizes is not None else self.src_sizes if chunk_width is not None: # remove those whose lengths are shorter than chunk_size @@ -431,6 +436,7 @@ def __init__( ] else: self.buckets = None + self.pad_to_multiple = pad_to_multiple def _match_src_tgt(self): """Makes utterances in src and tgt the same order in terms of @@ -538,6 +544,7 @@ def collater(self, samples, pad_to_length=None): seed=self.seed, epoch=self.epoch, pad_to_length=pad_to_length, + pad_to_multiple=self.pad_to_multiple, src_bucketed=(self.buckets is not None), random_chunking=self.random_chunking, ) @@ -596,24 +603,12 @@ def filter_indices_by_size(self, indices, max_sizes): np.array: filtered sample array list: list of removed indices """ - if max_sizes is None: - return indices, [] - if type(max_sizes) in (int, float): - max_src_size, max_tgt_size = max_sizes, max_sizes - else: - max_src_size, max_tgt_size = max_sizes - if self.tgt_sizes is None: - ignored = indices[self.src_sizes[indices] > max_src_size] - else: - ignored = indices[(self.src_sizes[indices] > max_src_size) | - (self.tgt_sizes[indices] > max_tgt_size)] - if len(ignored) > 0: - if self.tgt_sizes is None: - indices = indices[self.src_sizes[indices] <= max_src_size] - else: - indices = indices[(self.src_sizes[indices] <= max_src_size) & - (self.tgt_sizes[indices] <= max_tgt_size)] - return indices, ignored.tolist() + return data_utils.filter_paired_dataset_indices_by_size( + self.src_sizes, + self.tgt_sizes, + indices, + max_sizes, + ) @property def can_reuse_epoch_itr_across_epochs(self): diff --git a/espresso/models/lstm_lm.py b/espresso/models/lstm_lm.py index 9be64291f..6fc9cca3b 100644 --- a/espresso/models/lstm_lm.py +++ b/espresso/models/lstm_lm.py @@ -3,7 +3,12 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from dataclasses import dataclass, field +from omegaconf import II +from typing import Optional + from fairseq import utils +from fairseq.dataclass.utils import FairseqDataclass, gen_parser_from_dataclass from fairseq.models import ( FairseqLanguageModel, register_model, @@ -18,6 +23,70 @@ DEFAULT_MAX_TARGET_POSITIONS = 1e5 +@dataclass +class LSTMLanguageModelEspressoConfig(FairseqDataclass): + dropout: float = field(default=0.1, metadata={"help": "dropout probability"}) + decoder_embed_dim: int = field( + default=48, metadata={"help": "decoder embedding dimension"} + ) + decoder_embed_path: Optional[str] = field( + default=None, metadata={"help": "path to pre-trained decoder embedding"} + ) + decoder_freeze_embed: bool = field( + default=False, metadata={"help": "freeze decoder embeddings"} + ) + decoder_hidden_size: int = field( + default=650, metadata={"help": "decoder hidden size"} + ) + decoder_layers: int = field( + default=2, metadata={"help": "number of decoder layers"} + ) + decoder_out_embed_dim: int = field( + default=650, metadata={"help": "decoder output embedding dimension"} + ) + decoder_rnn_residual: lambda x: utils.eval_bool(x) = field( + default=False, + metadata={ + "help": "create residual connections for rnn decoder layers " + "(starting from the 2nd layer), i.e., the actual output of such " + "layer is the sum of its input and output" + }, + ) + adaptive_softmax_cutoff: Optional[str] = field( + default=None, + metadata={ + "help": "comma separated list of adaptive softmax cutoff points. " + "Must be used with adaptive_loss criterion" + }, + ) + share_embed: lambda x: utils.eval_bool(x) = field( + default=False, metadata={"help": "share input and output embeddings"} + ) + is_wordlm: bool = field( + default=False, + metadata={ + "help": "whether it is word LM or subword LM. Only relevant for ASR decoding " + "with LM, and it determines how the underlying decoder instance gets the " + "dictionary from the task instance when calling cls.build_model()" + }, + ) + # Granular dropout settings (if not specified these default to --dropout) + decoder_dropout_in: float = field( + default=0.1, + metadata={"help": "dropout probability for decoder input embedding"} + ) + decoder_dropout_out: float = field( + default=0.1, + metadata={"help": "dropout probability for decoder output"} + ) + # TODO common var add to parent + add_bos_token: bool = II("task.add_bos_token") + tokens_per_sample: int = II("task.tokens_per_sample") + max_target_positions: Optional[int] = II("task.max_target_positions") + # TODO common var add to parent + tpu: bool = II("params.common.tpu") + + @register_model("lstm_lm_espresso") class LSTMLanguageModelEspresso(FairseqLanguageModel): def __init__(self, decoder, args): @@ -42,6 +111,11 @@ def add_args(parser): help="number of decoder layers") parser.add_argument("--decoder-out-embed-dim", type=int, metavar="N", help="decoder output embedding dimension") + parser.add_argument("--decoder-rnn-residual", + type=lambda x: utils.eval_bool(x), + help="create residual connections for rnn decoder " + "layers (starting from the 2nd layer), i.e., the actual " + "output of such layer is the sum of its input and output") parser.add_argument("--adaptive-softmax-cutoff", metavar="EXPR", help="comma separated list of adaptive softmax cutoff points. 
" "Must be used with adaptive_loss criterion") diff --git a/espresso/optim/lr_scheduler/reduce_lr_on_plateau_v2.py b/espresso/optim/lr_scheduler/reduce_lr_on_plateau_v2.py index 2bfa19093..2c9adff1b 100644 --- a/espresso/optim/lr_scheduler/reduce_lr_on_plateau_v2.py +++ b/espresso/optim/lr_scheduler/reduce_lr_on_plateau_v2.py @@ -3,12 +3,57 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass, field +from omegaconf import II +from typing import List + import torch.optim.lr_scheduler +from fairseq.dataclass.utils import FairseqDataclass, gen_parser_from_dataclass from fairseq.optim.lr_scheduler import register_lr_scheduler from fairseq.optim.lr_scheduler.reduce_lr_on_plateau import ReduceLROnPlateau +@dataclass +class ReduceLROnPlateauV2Config(FairseqDataclass): + lr_shrink: float = field( + default=0.1, + metadata={"help": "shrink factor for annealing, lr_new = (lr * lr_shrink)"}, + ) + lr_threshold: float = field( + default=1e-4, + metadata={ + "help": "threshold for measuring the new optimum, to only focus on significant changes" + }, + ) + lr_patience: int = field( + default=0, + metadata={ + "help": "number of epochs with no improvement after which learning rate will be reduced" + }, + ) + warmup_updates: int = field( + default=0, + metadata={"help": "warmup the learning rate linearly for the first N updates"}, + ) + warmup_init_lr: float = field( + default=-1, + metadata={ + "help": "initial learning rate during warmup phase; default is args.lr" + }, + ) + final_lr_scale: float = field( + default=0.01, + metadata={"help": "final learning rate scale; default to 0.01"}, + ) + start_reduce_lr_epoch: int = field( + default=0, + metadata={"help": "start to reduce lr from the specified epoch"}, + ) + # TODO common vars at parent class + lr: List[float] = II("params.optimization.lr") + + @register_lr_scheduler('reduce_lr_on_plateau_v2') class ReduceLROnPlateauV2(ReduceLROnPlateau): """Decay the LR by a factor every time the validation loss plateaus, starting @@ -30,13 +75,7 @@ def __init__(self, args, optimizer): @staticmethod def add_args(parser): """Add arguments to the parser for this LR scheduler.""" - ReduceLROnPlateau.add_args(parser) - # fmt: off - parser.add_argument('--final-lr-scale', default=0.01, type=float, metavar='N', - help='final learning rate scale; default to 0.01') - parser.add_argument('--start-reduce-lr-epoch', default=0, type=int, metavar='N', - help='start to reduce lr from the specified epoch') - # fmt: on + gen_parser_from_dataclass(parser, ReduceLROnPlateauV2Config()) def step(self, epoch, val_loss=None): if epoch < self.args.start_reduce_lr_epoch: diff --git a/espresso/speech_train.py b/espresso/speech_train.py index d794f047b..8c8e7fe2c 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -300,6 +300,7 @@ def validate(args, trainer, task, epoch_itr, subsets): # set fixed seed for every validation utils.set_torch_seed(args.fixed_validation_seed) + trainer.begin_valid_epoch(epoch_itr.epoch) valid_losses = [] for subset in subsets: logger.info('begin validation on "{}" subset'.format(subset)) diff --git a/espresso/tasks/language_modeling_for_asr.py b/espresso/tasks/language_modeling_for_asr.py index 8f3c299e0..4afef4de0 100644 --- a/espresso/tasks/language_modeling_for_asr.py +++ b/espresso/tasks/language_modeling_for_asr.py @@ -5,13 +5,15 @@ import logging import os +from dataclasses import dataclass, field import torch from fairseq import 
tokenizer, utils from fairseq.data import TruncatedDictionary +from fairseq.dataclass.utils import gen_parser_from_dataclass from fairseq.tasks import register_task -from fairseq.tasks.language_modeling import LanguageModelingTask +from fairseq.tasks.language_modeling import LanguageModelingTask, LanguageModelingConfig from espresso.data import AsrDictionary @@ -19,6 +21,11 @@ logger = logging.getLogger(__name__) +@dataclass +class LanguageModelingForASRConfig(LanguageModelingConfig): + dict: str = field(default=None, metadata={"help": "path to the dictionary"}) + + @register_task("language_modeling_for_asr") class LanguageModelingForASRTask(LanguageModelingTask): """ @@ -51,12 +58,8 @@ class LanguageModelingForASRTask(LanguageModelingTask): @staticmethod def add_args(parser): - """Add task-specific arguments to the parser.""" - # fmt: off - LanguageModelingTask.add_args(parser) - parser.add_argument('--dict', default=None, type=str, - help='path to the dictionary') - # fmt: on + """Add task-specific arguments to the parser. Optionally register config store""" + gen_parser_from_dataclass(parser, LanguageModelingForASRConfig()) def __init__(self, args, dictionary, output_dictionary=None, targets=None): super().__init__(args, dictionary, output_dictionary, targets=targets) diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index c39c5abb0..544162be5 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -31,6 +31,7 @@ def get_asr_dataset_from_json( data_path, split, tgt_dict, combine, upsample_primary, num_buckets=0, shuffle=True, + pad_to_multiple=1, seed=1, specaugment_config=None, ): """ @@ -115,6 +116,7 @@ def get_asr_dataset_from_json( left_pad_target=False, num_buckets=num_buckets, shuffle=shuffle, + pad_to_multiple=pad_to_multiple, ) @@ -245,6 +247,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): upsample_primary=self.args.upsample_primary, num_buckets=self.args.num_batch_buckets, shuffle=(split != getattr(self.args, "gen_subset", None)), + pad_to_multiple=self.args.required_seq_len_multiple, seed=self.args.seed, specaugment_config=self.specaugment_config, ) diff --git a/espresso/tasks/speech_recognition_hybrid.py b/espresso/tasks/speech_recognition_hybrid.py index b52b70037..6eedeaa84 100644 --- a/espresso/tasks/speech_recognition_hybrid.py +++ b/espresso/tasks/speech_recognition_hybrid.py @@ -38,8 +38,8 @@ def get_asr_dataset_from_json( data_path, split, dictionary, combine, upsample_primary, - num_buckets=0, - shuffle=True, + num_buckets=0, shuffle=True, + pad_to_multiple=1, lf_mmi=True, seed=1, specaugment_config=None, chunk_width=None, chunk_left_context=None, chunk_right_context=None, label_delay=0, @@ -149,6 +149,7 @@ def get_asr_dataset_from_json( text=text_dataset, num_buckets=num_buckets, shuffle=shuffle, + pad_to_multiple=pad_to_multiple, ) else: return AsrXentDataset( @@ -157,6 +158,7 @@ def get_asr_dataset_from_json( text=text_dataset, num_buckets=num_buckets, shuffle=shuffle, + pad_to_multiple=pad_to_multiple, seed=seed, chunk_width=chunk_width, chunk_left_context=chunk_left_context, chunk_right_context=chunk_right_context, label_delay=label_delay, random_chunking=(split == "train" and chunk_width is not None), @@ -328,6 +330,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): upsample_primary=self.args.upsample_primary, num_buckets=self.args.num_batch_buckets, shuffle=(split != getattr(self.args, "gen_subset", None)), + 
pad_to_multiple=self.args.required_seq_len_multiple, lf_mmi=(self.args.criterion == "lattice_free_mmi"), seed=self.args.seed, specaugment_config=self.specaugment_config, chunk_width=None if self.training_stage and split in self.args.valid_subset.split(",") else self.chunk_width, diff --git a/espresso/tools/utils.py b/espresso/tools/utils.py index c0550ec74..19fe51ab5 100644 --- a/espresso/tools/utils.py +++ b/espresso/tools/utils.py @@ -38,11 +38,13 @@ def tokenize(sent, space='', non_lang_syms=None): return ' '.join(tokens) -def collate_frames(values, pad_value=0.0, left_pad=False, pad_to_length=None): +def collate_frames(values, pad_value=0.0, left_pad=False, pad_to_length=None, pad_to_multiple=1): """Convert a list of 2d tensor into a padded 3d tensor.""" assert values[0].dim() == 2, "expected 2, got " + str(values[0].dim) length = max(v.size(0) for v in values) length = length if pad_to_length is None else max(length, pad_to_length) + if pad_to_multiple != 1 and length % pad_to_multiple != 0: + length = (length + pad_to_multiple - 1) // pad_to_multiple * pad_to_multiple dim = values[0].size(1) res = values[0].new(len(values), length, dim).fill_(pad_value) diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index ebea3b20e..f94e9f873 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -13,8 +13,6 @@ from fairseq.models import FairseqIncrementalDecoder from torch import Tensor -from espresso.models.external_language_model import RawOutExternalLanguageModelBase - class SequenceGenerator(nn.Module): def __init__( @@ -1082,6 +1080,7 @@ def forward_decoder( None if decoder_len <= 1 else decoder_out[1], ) + from espresso.models.external_language_model import RawOutExternalLanguageModelBase if isinstance(model, RawOutExternalLanguageModelBase): probs = decoder_out_tuple[0] else: From 8c02b4595d7578b4fb35b6bf3cb182ea141049cd Mon Sep 17 00:00:00 2001 From: freewym Date: Thu, 1 Oct 2020 16:19:56 -0400 Subject: [PATCH 100/119] code adaptation/changes according to the commits on Oct 1, 2020 --- espresso/models/speech_transformer.py | 7 +++++++ espresso/speech_train.py | 1 + 2 files changed, 8 insertions(+) diff --git a/espresso/models/speech_transformer.py b/espresso/models/speech_transformer.py index 2f587493d..b25d57ced 100644 --- a/espresso/models/speech_transformer.py +++ b/espresso/models/speech_transformer.py @@ -425,6 +425,7 @@ def forward( encoder_out: Optional[EncoderOut] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, features_only: bool = False, + full_context_alignment: bool = False, alignment_layer: Optional[int] = None, alignment_heads: Optional[int] = None, src_lengths: Optional[Any] = None, @@ -441,6 +442,8 @@ def forward( :ref:`Incremental decoding` features_only (bool, optional): only return features without applying output layer (default: False). + full_context_alignment (bool, optional): don't apply + auto-regressive mask to self-attention (default: False). 
Returns: tuple: @@ -457,6 +460,7 @@ def forward( return self._forward_with_scheduled_sampling( prev_output_tokens, sampling_prob, encoder_out=encoder_out, incremental_state={}, # use empty dict to preserve forward state + full_context_alignment=full_context_alignment, alignment_layer=alignment_layer, alignment_heads=alignment_heads, src_lengths=src_lengths, @@ -467,6 +471,7 @@ def forward( prev_output_tokens, encoder_out=encoder_out, incremental_state=incremental_state, + full_context_alignment=full_context_alignment, alignment_layer=alignment_layer, alignment_heads=alignment_heads, ) @@ -481,6 +486,7 @@ def _forward_with_scheduled_sampling( encoder_out: Optional[EncoderOut] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, features_only: bool = False, + full_context_alignment: bool = False, alignment_layer: Optional[int] = None, alignment_heads: Optional[int] = None, src_lengths: Optional[Any] = None, @@ -503,6 +509,7 @@ def _forward_with_scheduled_sampling( feed_tokens, encoder_out=encoder_out, incremental_state=incremental_state, + full_context_alignment=full_context_alignment, alignment_layer=alignment_layer, alignment_heads=alignment_heads, ) diff --git a/espresso/speech_train.py b/espresso/speech_train.py index 8c8e7fe2c..26763e92b 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -199,6 +199,7 @@ def train(args, trainer, task, epoch_itr): if hasattr(trainer.criterion, "set_epoch"): trainer.criterion.set_epoch(epoch_itr.epoch) + valid_losses = [None] valid_subsets = args.valid_subset.split(",") should_stop = False num_updates = trainer.get_num_updates() From 707ce3abea2db41cf13e4a21a8805f0c41b5134f Mon Sep 17 00:00:00 2001 From: freewym Date: Sun, 11 Oct 2020 19:48:30 -0400 Subject: [PATCH 101/119] code adaptation/changes according to the commits on Oct 2-15, 2020 --- espresso/criterions/__init__.py | 4 +- espresso/criterions/cross_entropy_v2.py | 9 +- .../label_smoothed_cross_entropy_v2.py | 51 ++++--- espresso/criterions/lf_mmi_loss.py | 23 ++-- .../subsampled_cross_entropy_with_accuracy.py | 12 +- espresso/data/asr_chain_dataset.py | 5 + espresso/data/asr_dataset.py | 5 + espresso/data/asr_dictionary.py | 2 +- espresso/data/asr_xent_dataset.py | 5 + espresso/data/encoders/__init__.py | 4 +- espresso/data/feat_text_dataset.py | 48 ++++--- espresso/dump_posteriors.py | 17 ++- espresso/models/__init__.py | 4 +- espresso/models/lstm_lm.py | 90 ++++--------- espresso/models/speech_transformer.py | 7 +- .../speech_transformer_encoder_model.py | 7 +- espresso/optim/__init__.py | 4 +- espresso/optim/lr_scheduler/__init__.py | 4 +- .../lr_scheduler/reduce_lr_on_plateau_v2.py | 16 ++- espresso/speech_recognize.py | 77 ++++++++--- espresso/speech_train.py | 6 +- espresso/tasks/__init__.py | 4 +- espresso/tasks/language_modeling_for_asr.py | 8 +- espresso/tasks/speech_recognition.py | 90 +------------ espresso/tools/wer.py | 6 +- examples/asr_librispeech/run.sh | 14 +- examples/asr_swbd/run.sh | 14 +- examples/asr_wsj/local/wsj_extend_dict.sh | 1 + examples/asr_wsj/run.sh | 22 +-- examples/asr_wsj/run_chain_e2e.sh | 6 +- examples/asr_wsj/run_chain_e2e_bichar.sh | 6 +- examples/asr_wsj/run_xent.sh | 6 +- fairseq/sequence_generator.py | 126 +++--------------- 33 files changed, 282 insertions(+), 421 deletions(-) create mode 120000 examples/asr_wsj/local/wsj_extend_dict.sh diff --git a/espresso/criterions/__init__.py b/espresso/criterions/__init__.py index 3edbc58f4..fac6ceb88 100644 --- a/espresso/criterions/__init__.py +++ 
b/espresso/criterions/__init__.py @@ -10,5 +10,5 @@ # automatically import any Python files in the criterions/ directory for file in os.listdir(os.path.dirname(__file__)): if not file.startswith("_") and not file.startswith(".") and file.endswith(".py"): - criterion_name = file[: file.find(".py")] - importlib.import_module("espresso.criterions." + criterion_name) + file_name = file[: file.find(".py")] + importlib.import_module("espresso.criterions." + file_name) diff --git a/espresso/criterions/cross_entropy_v2.py b/espresso/criterions/cross_entropy_v2.py index 12b5df138..c49d7a701 100644 --- a/espresso/criterions/cross_entropy_v2.py +++ b/espresso/criterions/cross_entropy_v2.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass, field -from omegaconf import II import logging import numpy as np @@ -14,7 +13,6 @@ from fairseq.criterions import register_criterion from fairseq.criterions.cross_entropy import CrossEntropyCriterion, CrossEntropyCriterionConfig from fairseq.data import data_utils -from fairseq.dataclass.utils import gen_parser_from_dataclass logger = logging.getLogger(__name__) @@ -30,7 +28,7 @@ class CrossEntropyV2CriterionConfig(CrossEntropyCriterionConfig): ) -@register_criterion("cross_entropy_v2") +@register_criterion("cross_entropy_v2", dataclass=CrossEntropyV2CriterionConfig) class CrossEntropyV2Criterion(CrossEntropyCriterion): def __init__(self, task, sentence_avg, print_training_sample_interval): @@ -41,11 +39,6 @@ def __init__(self, task, sentence_avg, print_training_sample_interval): self.epoch = 1 self.prev_num_updates = -1 - @staticmethod - def add_args(parser): - """Add criterion-specific arguments to the parser. Optionally register config store""" - gen_parser_from_dataclass(parser, CrossEntropyV2CriterionConfig()) - def forward(self, model, sample, reduce=True): """Compute the loss for the given sample; periodically print out randomly sampled predictions from the training set. 
diff --git a/espresso/criterions/label_smoothed_cross_entropy_v2.py b/espresso/criterions/label_smoothed_cross_entropy_v2.py index eb6e3fa13..cb29b46db 100644 --- a/espresso/criterions/label_smoothed_cross_entropy_v2.py +++ b/espresso/criterions/label_smoothed_cross_entropy_v2.py @@ -14,8 +14,8 @@ from fairseq.criterions import register_criterion from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion from fairseq.data import data_utils -from fairseq.dataclass.data_class import DDP_BACKEND_CHOICES -from fairseq.dataclass.utils import ChoiceEnum, FairseqDataclass, gen_parser_from_dataclass +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.dataclass.utils import gen_parser_from_dataclass logger = logging.getLogger(__name__) @@ -27,13 +27,24 @@ @dataclass class LabelSmoothedCrossEntropyV2CriterionConfig(FairseqDataclass): sentence_avg: bool = II("params.optimization.sentence_avg") - ddp_backend: DDP_BACKEND_CHOICES = II("params.distributed_training.ddp_backend") label_smoothing: float = field( default=0.0, metadata={ "help": "epsilon for label smoothing, 0 means no label smoothing" }, ) + report_accuracy: bool = field( + default=False, + metadata={ + "help": "report accuracy metric" + }, + ) + ignore_prefix_size: bool = field( + default=False, + metadata={ + "help": "ignore first N tokens" + }, + ) print_training_sample_interval: int = field( default=500, metadata={ @@ -111,14 +122,18 @@ def label_smoothed_nll_loss( return loss, nll_loss -@register_criterion("label_smoothed_cross_entropy_v2") +@register_criterion("label_smoothed_cross_entropy_v2", dataclass=LabelSmoothedCrossEntropyV2CriterionConfig) class LabelSmoothedCrossEntropyV2Criterion(LabelSmoothedCrossEntropyCriterion): def __init__( self, task, sentence_avg, label_smoothing, smoothing_type, print_training_sample_interval, unigram_pseudo_count, + ignore_prefix_size=0, report_accuracy=False, ): - super().__init__(task, sentence_avg, label_smoothing) + super().__init__( + task, sentence_avg, label_smoothing, + ignore_prefix_size=ignore_prefix_size, report_accuracy=report_accuracy, + ) self.dictionary = task.target_dictionary self.smoothing_type = smoothing_type @@ -131,10 +146,12 @@ def __init__( self.unigram_tensor.div_(self.unigram_tensor.sum()) self.prev_num_updates = -1 - @staticmethod - def add_args(parser): - """Add criterion-specific arguments to the parser. 
Optionally register config store""" - gen_parser_from_dataclass(parser, LabelSmoothedCrossEntropyV2CriterionConfig()) + @classmethod + def add_args(cls, parser): + """Add criterion-specific arguments to the parser.""" + dc = getattr(cls, '__dataclass', None) + if dc is not None: + gen_parser_from_dataclass(parser, dc()) def forward(self, model, sample, reduce=True): """Compute the loss for the given sample; periodically print out @@ -159,6 +176,10 @@ def forward(self, model, sample, reduce=True): "nsentences": sample["target"].size(0), "sample_size": sample_size, } + if self.report_accuracy: + n_correct, total = self.compute_accuracy(model, net_output, sample) + logging_output["n_correct"] = utils.item(n_correct.data) + logging_output["total"] = utils.item(total.data) if ( hasattr(model, "num_updates") and model.training and @@ -168,7 +189,7 @@ def forward(self, model, sample, reduce=True): ): # print a randomly sampled result every print_interval updates self.prev_num_updates = model.num_updates target = model.get_targets(sample, net_output) - pred = lprobs.argmax(-1).cpu() # bsz x len + pred = lprobs.view(target.size(0), -1, lprobs.size(-1)).argmax(-1).cpu() # bsz x len assert pred.size() == target.size() with data_utils.numpy_seed(model.num_updates): i = np.random.randint(0, len(sample["id"])) @@ -184,14 +205,14 @@ def forward(self, model, sample, reduce=True): def compute_loss( self, model, net_output, sample, reduce=True, smoothing_type="uniform" ): - lprobs = model.get_normalized_probs(net_output, log_probs=True) - target = model.get_targets(sample, net_output) + lprobs, target = self.get_lprobs_and_target(model, net_output, sample) + bsz = sample["target"].size(0) prob_mask = temporal_label_smoothing_prob_mask( - lprobs, target, padding_index=self.padding_idx, + lprobs.view(bsz, -1, lprobs.size(-1)), target.view(bsz, -1), + padding_index=self.padding_idx, ) if smoothing_type == "temporal" else None loss, nll_loss = label_smoothed_nll_loss( - lprobs.view(-1, lprobs.size(-1)), target.view(-1, 1), self.eps, - ignore_index=self.padding_idx, reduce=reduce, + lprobs, target, self.eps, ignore_index=self.padding_idx, reduce=reduce, smoothing_type=smoothing_type, prob_mask=prob_mask, unigram_tensor=self.unigram_tensor, ) diff --git a/espresso/criterions/lf_mmi_loss.py b/espresso/criterions/lf_mmi_loss.py index 495c65612..997b1a422 100644 --- a/espresso/criterions/lf_mmi_loss.py +++ b/espresso/criterions/lf_mmi_loss.py @@ -12,8 +12,7 @@ from fairseq import utils from fairseq.criterions import FairseqCriterion, register_criterion -from fairseq.dataclass.data_class import DDP_BACKEND_CHOICES -from fairseq.dataclass.utils import ChoiceEnum, FairseqDataclass, gen_parser_from_dataclass +from fairseq.dataclass import FairseqDataclass from fairseq.logging import metrics @@ -23,7 +22,6 @@ @dataclass class LatticeFreeMMICriterionConfig(FairseqDataclass): sentence_avg: bool = II("params.optimization.sentence_avg") - ddp_backend: DDP_BACKEND_CHOICES = II("params.distributed_training.ddp_backend") denominator_fst_path: str = field( default=None, metadata={"help": "path to the denominator fst file"} ) @@ -131,12 +129,12 @@ def backward(ctx, objf_grad): return input_grad, None, None, None, None -@register_criterion("lattice_free_mmi") +@register_criterion("lattice_free_mmi", dataclass=LatticeFreeMMICriterionConfig) class LatticeFreeMMICriterion(FairseqCriterion): def __init__( - self, task, sentence_avg, denominator_fst_path, - leaky_hmm_coefficient, xent_regularize, output_l2_regularize, + self, task, 
sentence_avg, denominator_fst_path, leaky_hmm_coefficient, + xent_regularization_coefficient, output_l2_regularization_coefficient, ): super().__init__(task) try: @@ -152,13 +150,8 @@ def __init__( den_fst = simplefst.StdVectorFst.read(denominator_fst_path) self.den_graph = ChainGraph(den_fst, initial_mode="leaky", final_mode="ones") self.leaky_hmm_coefficient = leaky_hmm_coefficient - self.xent_regularize = xent_regularize - self.output_l2_regularize = output_l2_regularize - - @staticmethod - def add_args(parser): - """Add criterion-specific arguments to the parser. Optionally register config store""" - gen_parser_from_dataclass(parser, LatticeFreeMMICriterionConfig()) + self.xent_regularize = xent_regularization_coefficient + self.output_l2_regularize = output_l2_regularization_coefficient def forward(self, model, sample, reduce=True): """Compute the loss for the given sample. @@ -218,8 +211,8 @@ def compute_loss(self, net_output, sample, reduce=True): return loss, nll_loss - @staticmethod - def reduce_metrics(logging_outputs) -> None: + @classmethod + def reduce_metrics(cls, logging_outputs) -> None: """Aggregate logging outputs from data parallel training.""" loss_sum = sum(log.get("loss", 0) for log in logging_outputs) nll_loss_sum = sum(log.get('nll_loss', 0) for log in logging_outputs) diff --git a/espresso/criterions/subsampled_cross_entropy_with_accuracy.py b/espresso/criterions/subsampled_cross_entropy_with_accuracy.py index 3ee5f1699..6ad13df44 100644 --- a/espresso/criterions/subsampled_cross_entropy_with_accuracy.py +++ b/espresso/criterions/subsampled_cross_entropy_with_accuracy.py @@ -11,7 +11,6 @@ from fairseq.criterions import register_criterion from fairseq.criterions.cross_entropy import CrossEntropyCriterion, CrossEntropyCriterionConfig -from fairseq.dataclass.utils import gen_parser_from_dataclass from fairseq.logging import metrics @@ -23,7 +22,7 @@ class SubsampledCrossEntropyWithAccuracyCriterionConfig(CrossEntropyCriterionCon pass -@register_criterion("subsampled_cross_entropy_with_accuracy") +@register_criterion("subsampled_cross_entropy_with_accuracy", dataclass=SubsampledCrossEntropyWithAccuracyCriterionConfig) class SubsampledCrossEntropyWithAccuracyCriterion(CrossEntropyCriterion): def __init__(self, task, sentence_avg): @@ -34,11 +33,6 @@ def __init__(self, task, sentence_avg): self.transpose_net_output = getattr(task, "transpose_net_output", True) self.state_prior_update_interval = getattr(task, "state_prior_update_interval", None) - @staticmethod - def add_args(parser): - """Add task-specific arguments to the parser. optionaly register config store""" - gen_parser_from_dataclass(parser, SubsampledCrossEntropyWithAccuracyCriterionConfig()) - def forward(self, model, sample, reduce=True): """Compute the loss for the given sample. 
@@ -107,8 +101,8 @@ def compute_loss(self, model, net_output, sample, reduce=True): return loss, num_corr, num_tot, state_post - @staticmethod - def reduce_metrics(logging_outputs) -> None: + @classmethod + def reduce_metrics(cls, logging_outputs) -> None: """Aggregate logging outputs from data parallel training.""" CrossEntropyCriterion.reduce_metrics(logging_outputs) num_corr = sum(log.get("num_corr", 0) for log in logging_outputs) diff --git a/espresso/data/asr_chain_dataset.py b/espresso/data/asr_chain_dataset.py index 25cfc4198..2ed83741f 100644 --- a/espresso/data/asr_chain_dataset.py +++ b/espresso/data/asr_chain_dataset.py @@ -377,6 +377,11 @@ def filter_indices_by_size(self, indices, max_sizes): max_sizes, ) + @property + def supports_fetch_outside_dataloader(self): + """Whether this dataset supports fetching outside the workers of the dataloader.""" + return False + @property def can_reuse_epoch_itr_across_epochs(self): return False # to avoid running out of CPU RAM diff --git a/espresso/data/asr_dataset.py b/espresso/data/asr_dataset.py index 5fd41eedc..b0dc9259a 100644 --- a/espresso/data/asr_dataset.py +++ b/espresso/data/asr_dataset.py @@ -376,6 +376,11 @@ def filter_indices_by_size(self, indices, max_sizes): max_sizes, ) + @property + def supports_fetch_outside_dataloader(self): + """Whether this dataset supports fetching outside the workers of the dataloader.""" + return False + @property def can_reuse_epoch_itr_across_epochs(self): return False # to avoid running out of CPU RAM diff --git a/espresso/data/asr_dictionary.py b/espresso/data/asr_dictionary.py index 2b19b45b7..3c0a07912 100644 --- a/espresso/data/asr_dictionary.py +++ b/espresso/data/asr_dictionary.py @@ -71,7 +71,7 @@ def load(cls, f, f_non_lang_syms=None): if f_non_lang_syms is not None: assert isinstance(f_non_lang_syms, str) try: - with PathManager.open(f_non_lang_syms, "r", encoding="utf-8") as fd: + with open(PathManager.get_local_path(f_non_lang_syms), "r", encoding="utf-8") as fd: non_lang_syms = [x.rstrip() for x in fd.readlines()] except FileNotFoundError as fnfe: raise fnfe diff --git a/espresso/data/asr_xent_dataset.py b/espresso/data/asr_xent_dataset.py index 76749e9ba..e831345f9 100644 --- a/espresso/data/asr_xent_dataset.py +++ b/espresso/data/asr_xent_dataset.py @@ -610,6 +610,11 @@ def filter_indices_by_size(self, indices, max_sizes): max_sizes, ) + @property + def supports_fetch_outside_dataloader(self): + """Whether this dataset supports fetching outside the workers of the dataloader.""" + return False + @property def can_reuse_epoch_itr_across_epochs(self): return False # to avoid running out of CPU RAM diff --git a/espresso/data/encoders/__init__.py b/espresso/data/encoders/__init__.py index d8c6cd0fe..c9a579fb5 100644 --- a/espresso/data/encoders/__init__.py +++ b/espresso/data/encoders/__init__.py @@ -11,5 +11,5 @@ # automatically import any Python files in the encoders/ directory for file in os.listdir(os.path.dirname(__file__)): if not file.startswith("_") and not file.startswith(".") and file.endswith(".py"): - module = file[:file.find(".py")] - importlib.import_module("espresso.data.encoders." + module) + file_name = file[: file.find(".py")] + importlib.import_module("espresso.data.encoders." 
+ file_name) diff --git a/espresso/data/feat_text_dataset.py b/espresso/data/feat_text_dataset.py index c4c49383b..0dce55976 100644 --- a/espresso/data/feat_text_dataset.py +++ b/espresso/data/feat_text_dataset.py @@ -18,7 +18,7 @@ try: import kaldi_io except ImportError: - raise ImportError('Please install kaldi_io with: pip install kaldi_io') + raise ImportError("Please install kaldi_io with: pip install kaldi_io") class FeatScpDataset(torch.utils.data.Dataset): @@ -48,7 +48,7 @@ def __init__( try: feat = kaldi_io.read_mat(rxfile) except Exception: - raise Exception('failed to read feature matrix {}.'.format(rxfile)) + raise Exception("failed to read feature matrix {}.".format(rxfile)) assert feat is not None and isinstance(feat, np.ndarray) if len(self.sizes) == self.size: break @@ -63,14 +63,13 @@ def __init__( def check_index(self, i): if i < 0 or i >= self.size: - raise IndexError('index out of range') + raise IndexError("index out of range") def filter_and_reorder(self, indices): assert isinstance(indices, (list, np.ndarray)) indices = np.array(indices) assert all(indices < len(self.utt_ids)) and all(indices >= 0) - assert len(np.unique(indices)) == len(indices), \ - 'Duplicate elements in indices.' + assert len(np.unique(indices)) == len(indices), "Duplicate elements in indices." self.utt_ids = [self.utt_ids[i] for i in indices] self.rxfiles = [self.rxfiles[i] for i in indices] self.sizes = self.sizes[indices] @@ -121,6 +120,11 @@ def __init__( # self.ordered_indices, and doing this will speed up search of the # queried index self.ordered_prefetch = ordered_prefetch + # a flag to indicate whether self.prefetch() has been called. It is related + # to dummy_batch in trainer.py that uses the first batch when batch_by_size + # has been called but self.prefetch() has not. In this case we simply only + # load the queried samples into memory and don't do any caching. + self.prefetch_called = False @property def supports_prefetch(self): @@ -136,28 +140,31 @@ def prefetch(self, indices): assert self.size >= len(indices) self.ordered_indices = indices.copy() self.start_pos_for_next_cache = 0 + self.prefetch_called = True def __getitem__(self, i): self.check_index(i) + if not self.prefetch_called: # no caching + feat = kaldi_io.read_mat(self.rxfiles[i]) + return torch.from_numpy(feat).float() if i not in self.cache_index: - assert self.start_pos_for_next_cache < \ - len(self.ordered_indices), \ - 'Position for next cache starting beyond the end of ordered_indices.' + assert ( + self.start_pos_for_next_cache < len(self.ordered_indices) + ), "Position for next cache starting beyond the end of ordered_indices." try: pos_start = self.ordered_indices.index( i, self.start_pos_for_next_cache, ) except ValueError: raise ValueError( - 'index {} not found in self.ordered_indices. Set ' - 'self.ordered_prefetch to False, and/or call self.prefetch() ' - 'with the full list of indices, and then try again.'.format(i) + "index {} not found in self.ordered_indices. 
Set " + "self.ordered_prefetch to False, and/or call self.prefetch() " + "with the full list of indices, and then try again.".format(i) ) pos_end = min( pos_start + self.cache_size, len(self.ordered_indices), ) - self.start_pos_for_next_cache = pos_end \ - if self.ordered_prefetch else 0 + self.start_pos_for_next_cache = pos_end if self.ordered_prefetch else 0 total_size = 0 for idx in self.ordered_indices[pos_start: pos_end]: total_size += self.sizes[idx] @@ -251,20 +258,23 @@ def read_text(self, utt_ids: List[str], token_text: List[str], dictionary=None): self.sizes = np.array(self.sizes, dtype=np.int32) - assert len(self.utt_ids) == len(self.tokens_list) and \ - (dictionary is None or len(self.utt_ids) == len(self.tensor_list)) and \ - len(self.utt_ids) == len(self.sizes) + assert ( + len(self.utt_ids) == len(self.tokens_list) + and (dictionary is None or len(self.utt_ids) == len(self.tensor_list)) + and len(self.utt_ids) == len(self.sizes) + ) def check_index(self, i): if i < 0 or i >= self.size: - raise IndexError('index out of range') + raise IndexError("index out of range") def filter_and_reorder(self, indices): assert isinstance(indices, (list, np.ndarray)) indices = np.array(indices) assert all(indices < self.size) and all(indices >= 0) - assert len(np.unique(indices)) == len(indices), \ - 'Duplicate elements in indices.' + assert ( + len(np.unique(indices)) == len(indices) + ), "Duplicate elements in indices." self.utt_ids = [self.utt_ids[i] for i in indices] self.tokens_list = [self.tokens_list[i] for i in indices] if len(self.tensor_list) > 0: diff --git a/espresso/dump_posteriors.py b/espresso/dump_posteriors.py index ed24e4693..ee0348426 100755 --- a/espresso/dump_posteriors.py +++ b/espresso/dump_posteriors.py @@ -8,6 +8,7 @@ for decoding with Kaldi. 
""" +import ast import logging import os import sys @@ -44,7 +45,7 @@ def _main(args, output_file): utils.import_user_module(args) - if args.max_tokens is None and args.max_sentences is None: + if args.max_tokens is None and args.batch_size is None: args.max_tokens = 12000 logger.info(args) @@ -59,13 +60,17 @@ def _main(args, output_file): task = tasks.setup_task(args) task.load_dataset(args.gen_subset) + overrides = ast.literal_eval(args.model_overrides) + # Load ensemble logger.info("loading model(s) from {}".format(args.path)) models, _model_args = checkpoint_utils.load_model_ensemble( utils.split_paths(args.path), - arg_overrides=eval(args.model_overrides), + arg_overrides=overrides, task=task, suffix=getattr(args, "checkpoint_suffix", ""), + strict=(args.checkpoint_shard_count == 1), + num_shards=args.checkpoint_shard_count, ) # Load state prior for cross-entropy trained systems decoding @@ -76,11 +81,13 @@ def _main(args, output_file): # Optimize ensemble for generation for model in models: - model.prepare_for_inference_(args) + if model is None: + continue if args.fp16: model.half() - if use_cuda: + if use_cuda and not args.pipeline_model_parallel: model.cuda() + model.prepare_for_inference_(args) if isinstance(prior, list) and getattr(model, "state_prior", None) is not None: prior.append(model.state_prior.unsqueeze(0)) @@ -103,7 +110,7 @@ def _main(args, output_file): itr = task.get_batch_iterator( dataset=task.dataset(args.gen_subset), max_tokens=args.max_tokens, - max_sentences=args.max_sentences, + max_sentences=args.batch_size, max_positions=utils.resolve_max_positions( task.max_positions(), *[model.max_positions() if hasattr(model, "encoder") diff --git a/espresso/models/__init__.py b/espresso/models/__init__.py index 928f3caea..5dea15d65 100644 --- a/espresso/models/__init__.py +++ b/espresso/models/__init__.py @@ -10,5 +10,5 @@ # automatically import any Python files in the models/ directory for file in os.listdir(os.path.dirname(__file__)): if not file.startswith("_") and not file.startswith(".") and file.endswith(".py"): - model_name = file[:file.find(".py")] - importlib.import_module("espresso.models." + model_name) + file_name = file[: file.find(".py")] + importlib.import_module("espresso.models." 
+ file_name) diff --git a/espresso/models/lstm_lm.py b/espresso/models/lstm_lm.py index 6fc9cca3b..11e0764ca 100644 --- a/espresso/models/lstm_lm.py +++ b/espresso/models/lstm_lm.py @@ -8,7 +8,7 @@ from typing import Optional from fairseq import utils -from fairseq.dataclass.utils import FairseqDataclass, gen_parser_from_dataclass +from fairseq.dataclass import FairseqDataclass from fairseq.models import ( FairseqLanguageModel, register_model, @@ -70,7 +70,6 @@ class LSTMLanguageModelEspressoConfig(FairseqDataclass): "dictionary from the task instance when calling cls.build_model()" }, ) - # Granular dropout settings (if not specified these default to --dropout) decoder_dropout_in: float = field( default=0.1, metadata={"help": "dropout probability for decoder input embedding"} @@ -83,58 +82,15 @@ class LSTMLanguageModelEspressoConfig(FairseqDataclass): add_bos_token: bool = II("task.add_bos_token") tokens_per_sample: int = II("task.tokens_per_sample") max_target_positions: Optional[int] = II("task.max_target_positions") - # TODO common var add to parent tpu: bool = II("params.common.tpu") -@register_model("lstm_lm_espresso") +@register_model("lstm_lm_espresso", dataclass=LSTMLanguageModelEspressoConfig) class LSTMLanguageModelEspresso(FairseqLanguageModel): def __init__(self, decoder, args): super().__init__(decoder) self.is_wordlm = args.is_wordlm - @staticmethod - def add_args(parser): - """Add model-specific arguments to the parser.""" - # fmt: off - parser.add_argument("--dropout", type=float, metavar="D", - help="dropout probability") - parser.add_argument("--decoder-embed-dim", type=int, metavar="N", - help="decoder embedding dimension") - parser.add_argument("--decoder-embed-path", type=str, metavar="STR", - help="path to pre-trained decoder embedding") - parser.add_argument("--decoder-freeze-embed", action="store_true", - help="freeze decoder embeddings") - parser.add_argument("--decoder-hidden-size", type=int, metavar="N", - help="decoder hidden size") - parser.add_argument("--decoder-layers", type=int, metavar="N", - help="number of decoder layers") - parser.add_argument("--decoder-out-embed-dim", type=int, metavar="N", - help="decoder output embedding dimension") - parser.add_argument("--decoder-rnn-residual", - type=lambda x: utils.eval_bool(x), - help="create residual connections for rnn decoder " - "layers (starting from the 2nd layer), i.e., the actual " - "output of such layer is the sum of its input and output") - parser.add_argument("--adaptive-softmax-cutoff", metavar="EXPR", - help="comma separated list of adaptive softmax cutoff points. " - "Must be used with adaptive_loss criterion") - parser.add_argument("--share-embed", - type=lambda x: utils.eval_bool(x), - help="share input and output embeddings") - parser.add_argument("--is-wordlm", action="store_true", - help="whether it is word LM or subword LM. 
Only " - "relevant for ASR decoding with LM, and it determines " - "how the underlying decoder instance gets the dictionary " - "from the task instance when calling cls.build_model()") - - # Granular dropout settings (if not specified these default to --dropout) - parser.add_argument("--decoder-dropout-in", type=float, metavar="D", - help="dropout probability for decoder input embedding") - parser.add_argument("--decoder-dropout-out", type=float, metavar="D", - help="dropout probability for decoder output") - # fmt: on - @classmethod def build_model(cls, args, task): """Build a new model instance.""" @@ -225,33 +181,39 @@ def lstm_lm_wsj(args): @register_model_architecture("lstm_lm_espresso", "lstm_lm_librispeech") def lstm_lm_librispeech(args): - args.dropout = getattr(args, "dropout", 0.0) - args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 800) - args.decoder_hidden_size = getattr(args, "decoder_hidden_size", 800) - args.decoder_layers = getattr(args, "decoder_layers", 4) - args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 800) - args.share_embed = getattr(args, "share_embed", True) + args.dropout = 0.0 + args.decoder_embed_dim = 800 + args.decoder_hidden_size = 800 + args.decoder_layers = 4 + args.decoder_out_embed_dim = 800 + args.decoder_dropout_in = args.dropout + args.decoder_dropout_out = args.dropout + args.share_embed = True base_lm_architecture(args) @register_model_architecture("lstm_lm_espresso", "lstm_lm_swbd") def lstm_lm_swbd(args): - args.dropout = getattr(args, "dropout", 0.3) - args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1800) - args.decoder_hidden_size = getattr(args, "decoder_hidden_size", 1800) - args.decoder_layers = getattr(args, "decoder_layers", 3) - args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 1800) - args.share_embed = getattr(args, "share_embed", True) + args.dropout = 0.3 + args.decoder_embed_dim = 1800 + args.decoder_hidden_size = 1800 + args.decoder_layers = 3 + args.decoder_out_embed_dim = 1800 + args.decoder_dropout_in = args.dropout + args.decoder_dropout_out = args.dropout + args.share_embed = True base_lm_architecture(args) @register_model_architecture("lstm_lm_espresso", "lstm_wordlm_wsj") def lstm_wordlm_wsj(args): - args.dropout = getattr(args, "dropout", 0.35) - args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1200) - args.decoder_hidden_size = getattr(args, "decoder_hidden_size", 1200) - args.decoder_layers = getattr(args, "decoder_layers", 3) - args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 1200) - args.share_embed = getattr(args, "share_embed", True) + args.dropout = 0.35 + args.decoder_embed_dim = 1200 + args.decoder_hidden_size = 1200 + args.decoder_layers = 3 + args.decoder_out_embed_dim = 1200 + args.decoder_dropout_in = args.dropout + args.decoder_dropout_out = args.dropout + args.share_embed = True args.is_wordlm = True base_lm_architecture(args) diff --git a/espresso/models/speech_transformer.py b/espresso/models/speech_transformer.py index b25d57ced..5386b7299 100644 --- a/espresso/models/speech_transformer.py +++ b/espresso/models/speech_transformer.py @@ -330,7 +330,12 @@ def get_attn_mask(self, in_lengths): all_ones.triu(self.transformer_context[1] + 1) | all_ones.tril(-self.transformer_context[0] - 1) ) - def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False): + def forward( + self, + src_tokens, + src_lengths, + return_all_hiddens: bool = False, + ): """ Args: src_tokens (LongTensor): tokens in the source language 
of shape diff --git a/espresso/models/speech_transformer_encoder_model.py b/espresso/models/speech_transformer_encoder_model.py index 641e0acc3..989a35986 100644 --- a/espresso/models/speech_transformer_encoder_model.py +++ b/espresso/models/speech_transformer_encoder_model.py @@ -224,7 +224,12 @@ def __init__( self.fc_out = Linear(args.encoder_embed_dim, num_targets, dropout=self.dropout_module.p) \ if num_targets is not None else None - def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False): + def forward( + self, + src_tokens, + src_lengths, + return_all_hiddens: bool = False, + ): """ Args: src_tokens (LongTensor): tokens in the source language of shape diff --git a/espresso/optim/__init__.py b/espresso/optim/__init__.py index f922fa16e..afbb9be64 100644 --- a/espresso/optim/__init__.py +++ b/espresso/optim/__init__.py @@ -10,5 +10,5 @@ # automatically import any Python files in the optim/ directory for file in os.listdir(os.path.dirname(__file__)): if not file.startswith("_") and not file.startswith(".") and file.endswith(".py"): - module = file[:file.find(".py")] - importlib.import_module("espresso.optim." + module) + file_name = file[: file.find(".py")] + importlib.import_module("espresso.optim." + file_name) diff --git a/espresso/optim/lr_scheduler/__init__.py b/espresso/optim/lr_scheduler/__init__.py index a67e46579..f73d9290f 100644 --- a/espresso/optim/lr_scheduler/__init__.py +++ b/espresso/optim/lr_scheduler/__init__.py @@ -10,5 +10,5 @@ # automatically import any Python files in the optim/lr_scheduler/ directory for file in os.listdir(os.path.dirname(__file__)): if not file.startswith("_") and not file.startswith(".") and file.endswith(".py"): - module = file[:file.find(".py")] - importlib.import_module("espresso.optim.lr_scheduler." + module) + file_name = file[: file.find(".py")] + importlib.import_module("espresso.optim.lr_scheduler." + file_name) diff --git a/espresso/optim/lr_scheduler/reduce_lr_on_plateau_v2.py b/espresso/optim/lr_scheduler/reduce_lr_on_plateau_v2.py index 2c9adff1b..e8e919860 100644 --- a/espresso/optim/lr_scheduler/reduce_lr_on_plateau_v2.py +++ b/espresso/optim/lr_scheduler/reduce_lr_on_plateau_v2.py @@ -9,7 +9,8 @@ import torch.optim.lr_scheduler -from fairseq.dataclass.utils import FairseqDataclass, gen_parser_from_dataclass +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import gen_parser_from_dataclass from fairseq.optim.lr_scheduler import register_lr_scheduler from fairseq.optim.lr_scheduler.reduce_lr_on_plateau import ReduceLROnPlateau @@ -54,7 +55,7 @@ class ReduceLROnPlateauV2Config(FairseqDataclass): lr: List[float] = II("params.optimization.lr") -@register_lr_scheduler('reduce_lr_on_plateau_v2') +@register_lr_scheduler("reduce_lr_on_plateau_v2", dataclass=ReduceLROnPlateauV2Config) class ReduceLROnPlateauV2(ReduceLROnPlateau): """Decay the LR by a factor every time the validation loss plateaus, starting from the epoch specified as args.start_reduce_lr_epoch. 
@@ -68,14 +69,15 @@ def __init__(self, args, optimizer): self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( self.optimizer.optimizer, patience=args.lr_patience, factor=args.lr_shrink, - mode='max' if args.maximize_best_checkpoint_metric else 'min', + mode="max" if args.maximize_best_checkpoint_metric else "min", threshold=args.lr_threshold, min_lr=args.final_lr_scale * args.lr[0] ) - @staticmethod - def add_args(parser): - """Add arguments to the parser for this LR scheduler.""" - gen_parser_from_dataclass(parser, ReduceLROnPlateauV2Config()) + @classmethod + def add_args(cls, parser): + dc = getattr(cls, "__dataclass", None) + if dc is not None: + gen_parser_from_dataclass(parser, dc()) def step(self, epoch, val_loss=None): if epoch < self.args.start_reduce_lr_epoch: diff --git a/espresso/speech_recognize.py b/espresso/speech_recognize.py index 4279c946e..a08f10120 100755 --- a/espresso/speech_recognize.py +++ b/espresso/speech_recognize.py @@ -8,6 +8,8 @@ Recognize pre-processed speech with a trained model. """ +import ast +from itertools import chain import logging import math import os @@ -18,7 +20,6 @@ import torch from fairseq import checkpoint_utils, options, tasks, utils -from fairseq.data import encoders from fairseq.logging import progress_bar from fairseq.logging.meters import StopwatchMeter, TimeMeter from fairseq.models import FairseqLanguageModel @@ -64,7 +65,7 @@ def _main(args, output_file): utils.import_user_module(args) - if args.max_tokens is None and args.max_sentences is None: + if args.max_tokens is None and args.batch_size is None: args.max_tokens = 12000 logger.info(args) @@ -82,52 +83,79 @@ def _main(args, output_file): # Set dictionary dictionary = task.target_dictionary + overrides = ast.literal_eval(args.model_overrides) + # Load ensemble logger.info("loading model(s) from {}".format(args.path)) models, _model_args = checkpoint_utils.load_model_ensemble( utils.split_paths(args.path), - arg_overrides=eval(args.model_overrides), + arg_overrides=overrides, task=task, suffix=getattr(args, "checkpoint_suffix", ""), + strict=(args.checkpoint_shard_count == 1), + num_shards=args.checkpoint_shard_count, ) - for i, m in enumerate(models): + + if args.lm_path is not None: + overrides["data"] = args.data + + try: + lms, _ = checkpoint_utils.load_model_ensemble( + utils.split_paths(args.lm_path), + arg_overrides=overrides, + task=None, + ) + except: + logger.warning(f"Failed to load language model! 
Please make sure that the language model dict is the same " + f"as target dict and is located in the data dir ({args.data})") + raise + + assert len(lms) == 1 or len(lms) == 2 # Multi-level LM expects two LMs + else: + lms = [None] + + for i, m in enumerate(lms): + if m is None: + continue if hasattr(m, "is_wordlm") and m.is_wordlm: # assume subword LM comes before word LM - if isinstance(models[i - 1], FairseqLanguageModel): - models[i-1] = MultiLevelLanguageModel( - m, models[i-1], + if i > 0 and isinstance(lms[i - 1], FairseqLanguageModel): + lms[i - 1] = MultiLevelLanguageModel( + m, lms[i - 1], subwordlm_weight=args.subwordlm_weight, oov_penalty=args.oov_penalty, open_vocab=not args.disable_open_vocab, ) - del models[i] + del lms[i] logger.info("LM fusion with Multi-level LM") else: - models[i] = TensorizedLookaheadLanguageModel( + lms[i] = TensorizedLookaheadLanguageModel( m, dictionary, oov_penalty=args.oov_penalty, open_vocab=not args.disable_open_vocab, ) logger.info("LM fusion with Look-ahead Word LM") - # assume subword LM comes after E2E models - elif i == len(models) - 1 and isinstance(m, FairseqLanguageModel): + else: + assert isinstance(m, FairseqLanguageModel) logger.info("LM fusion with Subword LM") if args.lm_weight != 0.0: logger.info("using LM fusion with lm-weight={:.2f}".format(args.lm_weight)) # Optimize ensemble for generation - for model in models: - model.prepare_for_inference_(args) + for model in chain(models, lms): + if model is None: + continue if args.fp16: model.half() - if use_cuda: + if use_cuda and not args.pipeline_model_parallel: model.cuda() + model.prepare_for_inference_(args) # Load dataset (possibly sharded) itr = task.get_batch_iterator( dataset=task.dataset(args.gen_subset), max_tokens=args.max_tokens, - max_sentences=args.max_sentences, + max_sentences=args.batch_size, max_positions=utils.resolve_max_positions( task.max_positions(), *[model.max_positions() if hasattr(model, "encoder") @@ -153,11 +181,21 @@ def _main(args, output_file): "The option match_source_len is not applicable to speech recognition. Ignoring it." ) gen_timer = StopwatchMeter() - generator = task.build_generator(models, args) + + extra_gen_cls_kwargs = { + "lm_model": lms[0], + "lm_weight": args.lm_weight, + "eos_factor": args.eos_factor, + } + args.score_reference = False # not applicable for ASR + temp_val = args.print_alignment + args.print_alignment = False # not applicable for ASR + generator = task.build_generator(models, args, extra_gen_cls_kwargs=extra_gen_cls_kwargs) + args.print_alignment = temp_val # Handle tokenization and BPE - tokenizer = encoders.build_tokenizer(args) - bpe = encoders.build_bpe(args) + tokenizer = task.build_tokenizer(args) + bpe = task.build_bpe(args) def decode_fn(x): if bpe is not None: @@ -294,9 +332,6 @@ def cli_main(): parser.add_argument("--eos-factor", default=None, type=float, metavar="F", help="only consider emitting EOS if its score is no less " "than the specified factor of the best candidate score") - parser.add_argument("--lm-weight", default=0.0, type=float, metavar="W", - help="LM weight in log-prob space, assuming the pretrained " - "external LM is specified as the second one in --path") parser.add_argument("--subwordlm-weight", default=0.8, type=float, metavar="W", help="subword LM weight relative to word LM. 
Only relevant " "to MultiLevelLanguageModel as an external LM") diff --git a/espresso/speech_train.py b/espresso/speech_train.py index 26763e92b..1c2b2b4f3 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -42,8 +42,8 @@ def main(args): utils.import_user_module(args) assert ( - args.max_tokens is not None or args.max_sentences is not None - ), "Must specify batch size either with --max-tokens or --max-sentences" + args.max_tokens is not None or args.batch_size is not None + ), "Must specify batch size either with --max-tokens or --batch-size" metrics.reset() @@ -100,7 +100,7 @@ def main(args): ) logger.info( "max input frames per GPU = {} and max sentences per GPU = {}".format( - args.max_tokens, args.max_sentences + args.max_tokens, args.batch_size ) ) diff --git a/espresso/tasks/__init__.py b/espresso/tasks/__init__.py index 6739bb677..34fcadf49 100644 --- a/espresso/tasks/__init__.py +++ b/espresso/tasks/__init__.py @@ -10,5 +10,5 @@ # automatically import any Python files in the tasks/ directory for file in os.listdir(os.path.dirname(__file__)): if not file.startswith("_") and not file.startswith(".") and file.endswith(".py"): - task_name = file[:file.find(".py")] - importlib.import_module("espresso.tasks." + task_name) + file_name = file[: file.find(".py")] + importlib.import_module("espresso.tasks." + file_name) diff --git a/espresso/tasks/language_modeling_for_asr.py b/espresso/tasks/language_modeling_for_asr.py index 4afef4de0..4af8008b1 100644 --- a/espresso/tasks/language_modeling_for_asr.py +++ b/espresso/tasks/language_modeling_for_asr.py @@ -11,7 +11,6 @@ from fairseq import tokenizer, utils from fairseq.data import TruncatedDictionary -from fairseq.dataclass.utils import gen_parser_from_dataclass from fairseq.tasks import register_task from fairseq.tasks.language_modeling import LanguageModelingTask, LanguageModelingConfig @@ -26,7 +25,7 @@ class LanguageModelingForASRConfig(LanguageModelingConfig): dict: str = field(default=None, metadata={"help": "path to the dictionary"}) -@register_task("language_modeling_for_asr") +@register_task("language_modeling_for_asr", dataclass=LanguageModelingForASRConfig) class LanguageModelingForASRTask(LanguageModelingTask): """ Train a language model. @@ -56,11 +55,6 @@ class LanguageModelingForASRTask(LanguageModelingTask): :prog: """ - @staticmethod - def add_args(parser): - """Add task-specific arguments to the parser. 
Optionally register config store""" - gen_parser_from_dataclass(parser, LanguageModelingForASRConfig()) - def __init__(self, args, dictionary, output_dictionary=None, targets=None): super().__init__(args, dictionary, output_dictionary, targets=targets) torch.backends.cudnn.deterministic = True diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index 544162be5..5df459456 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -11,7 +11,7 @@ import torch -from fairseq import search, utils +from fairseq import utils from fairseq.data import BaseWrapperDataset, ConcatDataset from fairseq.logging import metrics from fairseq.tasks import FairseqTask, register_task @@ -269,94 +269,6 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): unk_count += (tgt_dataset[i][0] == self.tgt_dict.unk()).int().sum().item() self.tgt_dict.count[self.tgt_dict.unk()] = unk_count - def build_generator( - self, models, args, - seq_gen_cls=None, extra_gen_cls_kwargs=None - ): - if getattr(args, "score_reference", False): - args.score_reference = False - logger.warning( - "--score-reference is not applicable to speech recognition, ignoring it." - ) - - from fairseq.sequence_generator import SequenceGenerator - - # Choose search strategy. Defaults to Beam Search. - sampling = getattr(args, "sampling", False) - sampling_topk = getattr(args, "sampling_topk", -1) - sampling_topp = getattr(args, "sampling_topp", -1.0) - diverse_beam_groups = getattr(args, "diverse_beam_groups", -1) - diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5) - match_source_len = getattr(args, "match_source_len", False) - diversity_rate = getattr(args, "diversity_rate", -1) - constrained = getattr(args, "constraints", False) - if ( - sum( - int(cond) - for cond in [ - sampling, - diverse_beam_groups > 0, - match_source_len, - diversity_rate > 0, - ] - ) - > 1 - ): - raise ValueError("Provided Search parameters are mutually exclusive.") - assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling" - assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling" - - if sampling: - search_strategy = search.Sampling( - self.target_dictionary, sampling_topk, sampling_topp - ) - elif diverse_beam_groups > 0: - search_strategy = search.DiverseBeamSearch( - self.target_dictionary, diverse_beam_groups, diverse_beam_strength - ) - elif match_source_len: - # this is useful for tagging applications where the output - # length should match the input length, so we hardcode the - # length constraints for simplicity - search_strategy = search.LengthConstrainedBeamSearch( - self.target_dictionary, - min_len_a=1, - min_len_b=0, - max_len_a=1, - max_len_b=0, - ) - elif diversity_rate > -1: - search_strategy = search.DiverseSiblingsSearch( - self.target_dictionary, diversity_rate - ) - elif constrained: - search_strategy = search.LexicallyConstrainedBeamSearch(self.target_dictionary, args.constraints) - else: - search_strategy = search.BeamSearch(self.target_dictionary) - - if seq_gen_cls is None: - seq_gen_cls = SequenceGenerator - extra_gen_cls_kwargs = extra_gen_cls_kwargs or {} - extra_gen_cls_kwargs["lm_weight"] = getattr(args, "lm_weight", 0.0) - extra_gen_cls_kwargs["eos_factor"] = getattr(args, "eos_factor", None) - - return seq_gen_cls( - models, - self.target_dictionary, - beam_size=getattr(args, "beam", 5), - max_len_a=getattr(args, "max_len_a", 0), - max_len_b=getattr(args, "max_len_b", 200), - min_len=getattr(args, 
"min_len", 1), - normalize_scores=(not getattr(args, "unnormalized", False)), - len_penalty=getattr(args, "lenpen", 1), - unk_penalty=getattr(args, "unkpen", 0), - temperature=getattr(args, "temperature", 1.), - match_source_len=getattr(args, "match_source_len", False), - no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0), - search_strategy=search_strategy, - **extra_gen_cls_kwargs, - ) - def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): return AsrDataset( src_tokens, src_lengths, dictionary=self.target_dictionary, constraints=constraints, diff --git a/espresso/tools/wer.py b/espresso/tools/wer.py index 07c802883..5314682d9 100644 --- a/espresso/tools/wer.py +++ b/espresso/tools/wer.py @@ -7,6 +7,8 @@ import logging import re +from fairseq.file_io import PathManager + import espresso.tools.utils as speech_utils @@ -30,7 +32,7 @@ def reset(self): def parse_wer_output_filter(self, wer_output_filter): if wer_output_filter: - with open(wer_output_filter, 'r', encoding='utf-8') as f: + with open(PathManager.get_local_path(wer_output_filter), 'r', encoding='utf-8') as f: for line in f: line = line.strip() if line.startswith('#!') or line == '': @@ -143,7 +145,7 @@ def add_ordered_utt_list(self, *args): return self.ordered_utt_list = [] for text_file in args: - with open(text_file, 'r', encoding='utf-8') as f: + with open(PathManager.get_local_path(text_file), 'r', encoding='utf-8') as f: one_utt_list = [line.strip().split()[0] for line in f] self.ordered_utt_list.extend(one_utt_list) if len(self.char_results): diff --git a/examples/asr_librispeech/run.sh b/examples/asr_librispeech/run.sh index 5be8af65f..13fc0e98e 100755 --- a/examples/asr_librispeech/run.sh +++ b/examples/asr_librispeech/run.sh @@ -176,8 +176,8 @@ if [ ${stage} -le 5 ]; then CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../fairseq_cli/train.py $lmdatadir --seed 1 --user-dir espresso \ --task language_modeling_for_asr --dict $lmdict \ --log-interval $((16000/ngpus)) --log-format simple \ - --num-workers 0 --max-tokens 32000 --max-sentences 1024 --curriculum 1 \ - --valid-subset $valid_subset --max-sentences-valid 1536 \ + --num-workers 0 --max-tokens 32000 --batch-size 1024 --curriculum 1 \ + --valid-subset $valid_subset --batch-size-valid 1536 \ --distributed-world-size $ngpus \ --max-epoch 30 --optimizer adam --lr 0.001 --clip-norm 1.0 \ --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \ @@ -196,7 +196,7 @@ if [ ${stage} -le 6 ]; then log_file=$lmdir/log/evaluation_${test_set_array[$i]}.log python3 ../../fairseq_cli/eval_lm.py $lmdatadir --user-dir espresso --cpu \ --task language_modeling_for_asr --dict $lmdict --gen-subset ${gen_set_array[$i]} \ - --max-tokens 40960 --max-sentences 1536 --sample-break-mode eos \ + --max-tokens 40960 --batch-size 1536 --sample-break-mode eos \ --path $lmdir/$lm_checkpoint 2>&1 | tee $log_file done fi @@ -247,8 +247,8 @@ if [ ${stage} -le 8 ]; then fi CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data --task speech_recognition_espresso --seed 1 --user-dir espresso \ --log-interval $((8000/ngpus/update_freq)) --log-format simple --print-training-sample-interval $((4000/ngpus/update_freq)) \ - --num-workers 0 --data-buffer-size 0 --max-tokens 26000 --max-sentences 24 --curriculum 1 --empty-cache-freq 50 \ - --valid-subset $valid_subset --max-sentences-valid 48 --ddp-backend no_c10d --update-freq $update_freq \ + --num-workers 0 --data-buffer-size 0 --max-tokens 26000 --batch-size 24 --curriculum 1 --empty-cache-freq 50 \ + --valid-subset $valid_subset 
--batch-size-valid 48 --ddp-backend no_c10d --update-freq $update_freq \ --distributed-world-size $ngpus \ --optimizer adam --lr 0.001 --weight-decay 0.0 --clip-norm 2.0 \ --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates $((6000/ngpus/update_freq)) \ @@ -265,7 +265,7 @@ if [ ${stage} -le 9 ]; then path=$dir/$checkpoint decode_affix= if $lm_shallow_fusion; then - path="$path:$lmdir/$lm_checkpoint" + opts="$opts --lm-path $lmdir/$lm_checkpoint" opts="$opts --lm-weight 0.47 --eos-factor 1.5" if $apply_specaug; then # overwrite the existing opts @@ -276,7 +276,7 @@ if [ ${stage} -le 9 ]; then for dataset in $test_set; do decode_dir=$dir/decode_$dataset${decode_affix:+_${decode_affix}} CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py data \ - --task speech_recognition_espresso --user-dir espresso --max-tokens 15000 --max-sentences 24 \ + --task speech_recognition_espresso --user-dir espresso --max-tokens 15000 --batch-size 24 \ --num-shards 1 --shard-id 0 --dict $dict --bpe sentencepiece --sentencepiece-model ${sentencepiece_model}.model \ --gen-subset $dataset --max-source-positions 9999 --max-target-positions 999 \ --path $path --beam 60 --max-len-a 0.08 --max-len-b 0 --lenpen 1.0 \ diff --git a/examples/asr_swbd/run.sh b/examples/asr_swbd/run.sh index d59a218f5..95ab5403a 100755 --- a/examples/asr_swbd/run.sh +++ b/examples/asr_swbd/run.sh @@ -214,8 +214,8 @@ if [ $stage -le 4 ]; then CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../fairseq_cli/train.py $lmdatadir --seed 1 --user-dir espresso \ --task language_modeling_for_asr --dict $lmdict \ --log-interval $((1000/ngpus)) --log-format simple \ - --num-workers 0 --max-tokens 25600 --max-sentences 1024 \ - --valid-subset $valid_subset --max-sentences-valid 1536 \ + --num-workers 0 --max-tokens 25600 --batch-size 1024 \ + --valid-subset $valid_subset --batch-size-valid 1536 \ --distributed-world-size $ngpus \ --max-epoch 25 --optimizer adam --lr 0.001 --clip-norm 1.0 \ --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \ @@ -234,7 +234,7 @@ if [ $stage -le 5 ]; then log_file=$lmdir/log/evaluation_${test_set_array[$i]}.log python3 ../../fairseq_cli/eval_lm.py $lmdatadir --user-dir espresso --cpu \ --task language_modeling_for_asr --dict $lmdict --gen-subset ${gen_set_array[$i]} \ - --max-tokens 40960 --max-sentences 1536 --sample-break-mode eos \ + --max-tokens 40960 --batch-size 1536 --sample-break-mode eos \ --path $lmdir/$lm_checkpoint 2>&1 | tee $log_file done fi @@ -286,8 +286,8 @@ if [ $stage -le 7 ]; then fi CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data --task speech_recognition_espresso --seed 1 --user-dir espresso \ --log-interval $((3000/ngpus/update_freq)) --log-format simple --print-training-sample-interval $((4000/ngpus/update_freq)) \ - --num-workers 0 --data-buffer-size 0 --max-tokens 26000 --max-sentences 48 --curriculum 2 --empty-cache-freq 50 \ - --valid-subset $valid_subset --max-sentences-valid 64 --ddp-backend no_c10d --update-freq $update_freq \ + --num-workers 0 --data-buffer-size 0 --max-tokens 26000 --batch-size 48 --curriculum 2 --empty-cache-freq 50 \ + --valid-subset $valid_subset --batch-size-valid 64 --ddp-backend no_c10d --update-freq $update_freq \ --distributed-world-size $ngpus \ --optimizer adam --lr 0.001 --weight-decay 0.0 --clip-norm 2.0 \ --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates $((3000/ngpus/update_freq)) \ @@ -305,7 +305,7 @@ if [ $stage -le 8 ]; then path=$dir/$checkpoint decode_affix= if 
$lm_shallow_fusion; then - path="$path:$lmdir/$lm_checkpoint" + opts="$opts --lm-path $lmdir/$lm_checkpoint" opts="$opts --lm-weight 0.25" decode_affix=shallow_fusion fi @@ -313,7 +313,7 @@ if [ $stage -le 8 ]; then for dataset in $test_set; do decode_dir=$dir/decode_${dataset}${decode_affix:+_${decode_affix}} CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py data \ - --task speech_recognition_espresso --user-dir espresso --max-tokens 24000 --max-sentences 48 \ + --task speech_recognition_espresso --user-dir espresso --max-tokens 24000 --batch-size 48 \ --num-shards 1 --shard-id 0 --dict $dict --bpe sentencepiece --sentencepiece-model ${sentencepiece_model}.model \ --non-lang-syms $nlsyms --gen-subset $dataset --max-source-positions 9999 --max-target-positions 999 \ --path $path --beam 35 --max-len-a 0.1 --max-len-b 0 --lenpen 1.0 \ diff --git a/examples/asr_wsj/local/wsj_extend_dict.sh b/examples/asr_wsj/local/wsj_extend_dict.sh new file mode 120000 index 000000000..f3b575882 --- /dev/null +++ b/examples/asr_wsj/local/wsj_extend_dict.sh @@ -0,0 +1 @@ +../../../espresso/tools/kaldi/egs/wsj/s5/local/wsj_extend_dict.sh \ No newline at end of file diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index 6c372a198..ae3ea5fe7 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -192,8 +192,8 @@ if [ ${stage} -le 4 ] && ! $use_wordlm; then CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../fairseq_cli/train.py $lmdatadir --seed 1 --user-dir espresso \ --task language_modeling_for_asr --dict $lmdict \ --log-interval $((4000/ngpus)) --log-format simple \ - --num-workers 0 --max-tokens 25600 --max-sentences 128 \ - --valid-subset $valid_subset --max-sentences-valid 256 \ + --num-workers 0 --max-tokens 25600 --batch-size 128 \ + --valid-subset $valid_subset --batch-size-valid 256 \ --distributed-world-size $ngpus \ --max-epoch 25 --optimizer adam --lr 0.001 --weight-decay 5e-06 \ --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \ @@ -208,7 +208,7 @@ if [ ${stage} -le 5 ] && ! 
$use_wordlm; then log_file=$lmdir/log/evaluation_$gen_subset.log python3 ../../fairseq_cli/eval_lm.py $lmdatadir --user-dir espresso --cpu \ --task language_modeling_for_asr --dict $lmdict --gen-subset $gen_subset \ - --max-tokens 192000 --max-sentences 256 --sample-break-mode eos \ + --max-tokens 192000 --batch-size 256 --sample-break-mode eos \ --path $lmdir/$lm_checkpoint 2>&1 | tee $log_file done fi @@ -222,8 +222,8 @@ if [ ${stage} -le 6 ] && $use_wordlm; then CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../fairseq_cli/train.py $wordlmdatadir --seed 1 --user-dir espresso \ --task language_modeling_for_asr --dict $wordlmdict \ --log-interval $((4000/ngpus)) --log-format simple \ - --num-workers 0 --max-tokens 6400 --max-sentences 256 \ - --valid-subset $valid_subset --max-sentences-valid 512 \ + --num-workers 0 --max-tokens 6400 --batch-size 256 \ + --valid-subset $valid_subset --batch-size-valid 512 \ --distributed-world-size $ngpus \ --max-epoch 25 --optimizer adam --lr 0.001 --weight-decay 0.0 \ --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \ @@ -239,7 +239,7 @@ if [ ${stage} -le 7 ] && $use_wordlm; then log_file=$wordlmdir/log/evaluation_$gen_subset.log python3 ../../fairseq_cli/eval_lm.py $wordlmdatadir --user-dir espresso --cpu \ --task language_modeling_for_asr --dict $wordlmdict --gen-subset $gen_subset \ - --max-tokens 12800 --max-sentences 512 --sample-break-mode eos \ + --max-tokens 12800 --batch-size 512 --sample-break-mode eos \ --path $wordlmdir/$lm_checkpoint 2>&1 | tee $log_file done fi @@ -285,8 +285,8 @@ if [ ${stage} -le 9 ]; then fi CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data --task speech_recognition_espresso --seed 1 --user-dir espresso \ --log-interval $((800/ngpus/update_freq)) --log-format simple --print-training-sample-interval $((2000/ngpus/update_freq)) \ - --num-workers 0 --data-buffer-size 0 --max-tokens 24000 --max-sentences 32 --curriculum 2 --empty-cache-freq 50 \ - --valid-subset $valid_subset --max-sentences-valid 64 --ddp-backend no_c10d --update-freq $update_freq \ + --num-workers 0 --data-buffer-size 0 --max-tokens 24000 --batch-size 32 --curriculum 2 --empty-cache-freq 50 \ + --valid-subset $valid_subset --batch-size-valid 64 --ddp-backend no_c10d --update-freq $update_freq \ --distributed-world-size $ngpus \ --optimizer adam --lr 0.001 --weight-decay 0.0 \ --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates $((800/ngpus/update_freq)) \ @@ -303,11 +303,11 @@ if [ ${stage} -le 10 ]; then decode_affix= if $lm_shallow_fusion; then if ! 
$use_wordlm; then - path="$path:$lmdir/$lm_checkpoint" + opts="$opts --lm-path $lmdir/$lm_checkpoint" opts="$opts --lm-weight 0.7 --eos-factor 1.5" decode_affix=shallow_fusion else - path="$path:$wordlmdir/$lm_checkpoint" + opts="$opts --lm-path $wordlmdir/$lm_checkpoint" opts="$opts --word-dict $wordlmdict --lm-weight 0.9 --oov-penalty 1e-7 --eos-factor 1.5" decode_affix=shallow_fusion_wordlm fi @@ -316,7 +316,7 @@ if [ ${stage} -le 10 ]; then for dataset in $valid_set $test_set; do decode_dir=$dir/decode_$dataset${decode_affix:+_${decode_affix}} CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py data \ - --task speech_recognition_espresso --user-dir espresso --max-tokens 20000 --max-sentences 32 \ + --task speech_recognition_espresso --user-dir espresso --max-tokens 20000 --batch-size 32 \ --num-shards 1 --shard-id 0 --dict $dict --bpe characters_asr --non-lang-syms $nlsyms \ --gen-subset $dataset --max-source-positions 9999 --max-target-positions 999 \ --path $path --beam 50 --max-len-a 0.2 --max-len-b 0 --lenpen 1.0 \ diff --git a/examples/asr_wsj/run_chain_e2e.sh b/examples/asr_wsj/run_chain_e2e.sh index 51a550ff7..3d524db80 100755 --- a/examples/asr_wsj/run_chain_e2e.sh +++ b/examples/asr_wsj/run_chain_e2e.sh @@ -187,8 +187,8 @@ if [ ${stage} -le 6 ]; then update_freq=1 CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data/chain_e2e --task speech_recognition_hybrid --seed 1 --user-dir espresso \ --log-interval $((200/ngpus/update_freq)) --log-format simple \ - --num-workers 0 --data-buffer-size 0 --max-tokens 120000 --max-sentences 128 --curriculum 1 --empty-cache-freq 50 \ - --valid-subset $valid_subset --max-sentences-valid 128 --ddp-backend no_c10d --update-freq $update_freq \ + --num-workers 0 --data-buffer-size 0 --max-tokens 120000 --batch-size 128 --curriculum 1 --empty-cache-freq 50 \ + --valid-subset $valid_subset --batch-size-valid 128 --ddp-backend no_c10d --update-freq $update_freq \ --distributed-world-size $ngpus \ --max-epoch 30 --optimizer adam --lr 0.001 --weight-decay 0.0 --start-reduce-lr-epoch 11 \ --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 \ @@ -213,7 +213,7 @@ if [ ${stage} -le 7 ]; then graph_dir=$tree_dir/graph_${lmtype} $decode_cmd $queue_opt JOB=1:$nj $dir/decode_${lmtype}_${data_affix}/log/decode.JOB.log \ dump_posteriors.py data/chain_e2e --cpu --task speech_recognition_hybrid --user-dir espresso \ - --max-tokens 120000 --max-sentences 128 --num-shards 1 --shard-id 0 --num-targets $num_targets \ + --max-tokens 120000 --batch-size 128 --num-shards 1 --shard-id 0 --num-targets $num_targets \ --gen-subset $dataset.JOB \ --max-source-positions 9999 --path $path \| \ latgen-faster-mapped --max-active=7000 --min-active=20 --beam=15 --lattice-beam=8 --acoustic-scale=1.0 \ diff --git a/examples/asr_wsj/run_chain_e2e_bichar.sh b/examples/asr_wsj/run_chain_e2e_bichar.sh index 4192a7103..8829f28e6 100755 --- a/examples/asr_wsj/run_chain_e2e_bichar.sh +++ b/examples/asr_wsj/run_chain_e2e_bichar.sh @@ -187,8 +187,8 @@ if [ ${stage} -le 6 ]; then update_freq=1 CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data/chain_e2e_bichar --task speech_recognition_hybrid --seed 1 --user-dir espresso \ --log-interval $((200/ngpus/update_freq)) --log-format simple \ - --num-workers 0 --data-buffer-size 0 --max-tokens 120000 --max-sentences 128 --curriculum 1 --empty-cache-freq 50 \ - --valid-subset $valid_subset --max-sentences-valid 128 --ddp-backend no_c10d --update-freq $update_freq \ + --num-workers 0 --data-buffer-size 0 
--max-tokens 120000 --batch-size 128 --curriculum 1 --empty-cache-freq 50 \ + --valid-subset $valid_subset --batch-size-valid 128 --ddp-backend no_c10d --update-freq $update_freq \ --distributed-world-size $ngpus \ --max-epoch 30 --optimizer adam --lr 0.001 --weight-decay 0.0 --start-reduce-lr-epoch 11 \ --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 \ @@ -213,7 +213,7 @@ if [ ${stage} -le 7 ]; then graph_dir=$tree_dir/graph_${lmtype} $decode_cmd $queue_opt JOB=1:$nj $dir/decode_${lmtype}_${data_affix}/log/decode.JOB.log \ dump_posteriors.py data/chain_e2e_bichar --cpu --task speech_recognition_hybrid --user-dir espresso \ - --max-tokens 120000 --max-sentences 128 --num-shards 1 --shard-id 0 --num-targets $num_targets \ + --max-tokens 120000 --batch-size 128 --num-shards 1 --shard-id 0 --num-targets $num_targets \ --gen-subset $dataset.JOB \ --max-source-positions 9999 --path $path \| \ latgen-faster-mapped --max-active=7000 --min-active=20 --beam=15 --lattice-beam=8 --acoustic-scale=1.0 \ diff --git a/examples/asr_wsj/run_xent.sh b/examples/asr_wsj/run_xent.sh index 63fbcb4f6..d66b937a3 100755 --- a/examples/asr_wsj/run_xent.sh +++ b/examples/asr_wsj/run_xent.sh @@ -167,8 +167,8 @@ if [ ${stage} -le 5 ]; then update_freq=1 CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data/xent --task speech_recognition_hybrid --seed 1 --user-dir espresso \ --log-interval $((100/ngpus/update_freq)) --log-format simple \ - --num-workers 0 --data-buffer-size 0 --max-tokens 160000 --max-sentences 256 --empty-cache-freq 50 \ - --valid-subset $valid_subset --max-sentences-valid 256 --ddp-backend no_c10d --update-freq $update_freq \ + --num-workers 0 --data-buffer-size 0 --max-tokens 160000 --batch-size 256 --empty-cache-freq 50 \ + --valid-subset $valid_subset --batch-size-valid 256 --ddp-backend no_c10d --update-freq $update_freq \ --distributed-world-size $ngpus \ --max-epoch 40 --optimizer adam --lr 0.001 --weight-decay 0.0 \ --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 \ @@ -193,7 +193,7 @@ if [ ${stage} -le 6 ]; then graph_dir=exp/$gmm/graph_${lmtype} $decode_cmd $queue_opt JOB=1:$nj $dir/decode_${lmtype}_${data_affix}/log/decode.JOB.log \ dump_posteriors.py data/xent --cpu --task speech_recognition_hybrid --user-dir espresso \ - --max-tokens 256000 --max-sentences 256 --num-shards 1 --shard-id 0 --num-targets $num_targets \ + --max-tokens 256000 --batch-size 256 --num-shards 1 --shard-id 0 --num-targets $num_targets \ --gen-subset $dataset.JOB --chunk-width 150 --chunk-left-context 10 --chunk-right-context 10 --label-delay -3 \ --max-source-positions 9999 --path $path --apply-log-softmax \| \ latgen-faster-mapped --max-active=7000 --min-active=20 --beam=15 --lattice-beam=8 --acoustic-scale=0.1 \ diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index f94e9f873..586e189f8 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -13,6 +13,8 @@ from fairseq.models import FairseqIncrementalDecoder from torch import Tensor +from espresso.models.external_language_model import RawOutExternalLanguageModelBase + class SequenceGenerator(nn.Module): def __init__( @@ -62,8 +64,7 @@ def __init__( if isinstance(models, EnsembleModel): self.model = models else: - lm_weight = kwargs.get("lm_weight", 0.0) - self.model = EnsembleModel(models) if lm_weight == 0.0 else LMFusionModel(models, lm_weight) + self.model = EnsembleModel(models) self.tgt_dict = tgt_dict self.pad = tgt_dict.pad() self.unk = tgt_dict.unk() @@ -193,6 +194,10 @@ def _generate( for i 
in range(self.model.models_size) ], ) + lm_incremental_state = torch.jit.annotate( + Dict[str, Dict[str, Optional[Tensor]]], + torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) + ) if self.lm_model is not None else None net_input = sample["net_input"] if "src_tokens" in net_input: @@ -310,6 +315,10 @@ def _generate( ) original_batch_idxs = original_batch_idxs[batch_idxs] self.model.reorder_incremental_state(incremental_states, reorder_state) + if self.lm_model is not None: + self.lm_model.decoder.reorder_incremental_state_scripting( + lm_incremental_state, reorder_state + ) encoder_outs = self.model.reorder_encoder_out( encoder_outs, reorder_state ) @@ -322,10 +331,13 @@ def _generate( ) if self.lm_model is not None: - lm_out = self.lm_model(tokens[:, : step + 1]) - probs = self.lm_model.get_normalized_probs( - lm_out, log_probs=True, sample=None - ) + lm_out = self.lm_model(tokens[:, : step + 1], incremental_state=lm_incremental_state) + if isinstance(self.lm_model, RawOutExternalLanguageModelBase): + probs = lm_out[0] + else: + probs = self.lm_model.get_normalized_probs( + lm_out, log_probs=True, sample=None + ) probs = probs[:, -1, :] * self.lm_weight lprobs += probs @@ -1019,105 +1031,3 @@ def forward_align(self, src_tokens, src_lengths, prev_output_tokens): if len(self.models) > 1: avg_attn.div_(len(self.models)) return avg_attn - - -class LMFusionModel(EnsembleModel): - """A wrapper around an ensemble of an LM fused model.""" - - def __init__(self, models, lm_weight): - super().__init__(models) - self.lm_weight = lm_weight - assert self.models_size == 2, "Only support LM fusion with one E2E model" - assert self.has_encoder() - - @torch.jit.export - def forward_encoder(self, net_input: Dict[str, Tensor]): - return [ - model.encoder.forward_torchscript(net_input) if hasattr(model, "encoder") - else None for model in self.models - ] - - @torch.jit.export - def forward_decoder( - self, - tokens, - encoder_outs: List[EncoderOut], - incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]], - temperature: float = 1.0, - ): - log_probs = [] - avg_attn: Optional[Tensor] = None - attn_count = 0 - encoder_out: Optional[EncoderOut] = None - for i, model in enumerate(self.models): - encoder_out = encoder_outs[i] - # decode each model - if self.has_incremental_states(): - decoder_out = model.decoder.forward( - tokens, - encoder_out=encoder_out, - incremental_state=incremental_states[i], - ) - else: - decoder_out = model.decoder.forward(tokens, encoder_out=encoder_out) - - attn: Optional[Tensor] = None - decoder_len = len(decoder_out) - if decoder_len > 1 and decoder_out[1] is not None: - if isinstance(decoder_out[1], Tensor): - attn = decoder_out[1] - else: - attn_holder = decoder_out[1]["attn"] - if isinstance(attn_holder, Tensor): - attn = attn_holder - elif attn_holder is not None: - attn = attn_holder[0] - if attn is not None: - attn = attn[:, -1, :] - - decoder_out_tuple = ( - decoder_out[0][:, -1:, :].div_(temperature), - None if decoder_len <= 1 else decoder_out[1], - ) - - from espresso.models.external_language_model import RawOutExternalLanguageModelBase - if isinstance(model, RawOutExternalLanguageModelBase): - probs = decoder_out_tuple[0] - else: - probs = model.get_normalized_probs( - decoder_out_tuple, log_probs=True, sample=None - ) - probs = probs[:, -1, :] - if i == 1 and self.lm_weight != 1.0: # assuming LM is the last model - probs.mul_(self.lm_weight) - - log_probs.append(probs) - if attn is not None: - if avg_attn is None: - avg_attn = attn - else: - 
avg_attn.add_(attn) - attn_count += 1 - avg_probs = torch.sum(torch.stack(log_probs, dim=0), dim=0) - if avg_attn is not None: - avg_attn.div_(attn_count) - return avg_probs, avg_attn - - @torch.jit.export - def reorder_encoder_out(self, encoder_outs: Optional[List[EncoderOut]], new_order): - """ - Reorder encoder output according to *new_order*. - - Args: - encoder_out: output from the ``forward()`` method - new_order (LongTensor): desired order - - Returns: - *encoder_out* rearranged according to *new_order* - """ - new_outs: List[EncoderOut] = [] - for i, model in enumerate(self.models): - new_outs.append( - model.encoder.reorder_encoder_out(encoder_outs[i], new_order) if hasattr(model, "encoder") else None - ) - return new_outs From 262c0a29dcfd6de4c266b1dda8ede5f6515ab230 Mon Sep 17 00:00:00 2001 From: freewym Date: Sun, 25 Oct 2020 22:27:48 -0400 Subject: [PATCH 102/119] code adaptation/changes according to the commits on Oct 18-Nov 3, 2020 (lots of changes, mostly for adapting to hydra configs and code formatting) --- .../label_smoothed_cross_entropy_v2.py | 35 +- espresso/criterions/lf_mmi_loss.py | 9 +- espresso/data/asr_chain_dataset.py | 50 ++- espresso/data/asr_dataset.py | 78 +++-- espresso/data/asr_dictionary.py | 31 +- espresso/data/asr_xent_dataset.py | 101 ++++-- espresso/data/encoders/characters_asr.py | 15 +- espresso/data/feat_text_dataset.py | 8 +- espresso/dump_posteriors.py | 108 +++--- espresso/models/external_language_model.py | 104 +++--- espresso/models/lstm_lm.py | 106 ++++-- espresso/models/speech_lstm.py | 156 ++++++--- espresso/models/speech_lstm_encoder_model.py | 85 +++-- espresso/models/speech_tdnn.py | 37 +- espresso/models/speech_transformer.py | 25 +- .../speech_transformer_encoder_model.py | 53 ++- .../tensorized_lookahead_language_model.py | 33 +- espresso/modules/__init__.py | 4 +- espresso/modules/speech_attention.py | 22 +- .../lr_scheduler/reduce_lr_on_plateau_v2.py | 27 +- espresso/speech_recognize.py | 210 ++++++------ espresso/speech_train.py | 216 ++++++------ espresso/tasks/language_modeling_for_asr.py | 3 +- espresso/tasks/speech_recognition.py | 196 +++++++---- espresso/tasks/speech_recognition_hybrid.py | 321 +++++++++++------- espresso/tools/compute_wer.py | 64 ++-- ...ate_initial_state_prior_from_alignments.py | 2 +- espresso/tools/lexical_prefix_tree.py | 11 +- espresso/tools/text2token.py | 47 +-- espresso/tools/text2vocabulary.py | 98 +++--- espresso/tools/utils.py | 124 +++---- espresso/tools/wer.py | 107 +++--- examples/asr_librispeech/run.sh | 10 +- examples/asr_swbd/run.sh | 10 +- examples/asr_wsj/run.sh | 16 +- examples/asr_wsj/run_chain_e2e.sh | 4 +- examples/asr_wsj/run_chain_e2e_bichar.sh | 4 +- examples/asr_wsj/run_xent.sh | 6 +- fairseq/dataclass/configs.py | 41 +++ tests/espresso/test_speech_utils.py | 8 +- 40 files changed, 1562 insertions(+), 1023 deletions(-) diff --git a/espresso/criterions/label_smoothed_cross_entropy_v2.py b/espresso/criterions/label_smoothed_cross_entropy_v2.py index cb29b46db..96285fe2b 100644 --- a/espresso/criterions/label_smoothed_cross_entropy_v2.py +++ b/espresso/criterions/label_smoothed_cross_entropy_v2.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. 
from dataclasses import dataclass, field -from omegaconf import II import logging import numpy as np @@ -16,6 +15,7 @@ from fairseq.data import data_utils from fairseq.dataclass import ChoiceEnum, FairseqDataclass from fairseq.dataclass.utils import gen_parser_from_dataclass +from omegaconf import II logger = logging.getLogger(__name__) @@ -26,7 +26,7 @@ @dataclass class LabelSmoothedCrossEntropyV2CriterionConfig(FairseqDataclass): - sentence_avg: bool = II("params.optimization.sentence_avg") + sentence_avg: bool = II("optimization.sentence_avg") label_smoothing: float = field( default=0.0, metadata={ @@ -85,7 +85,7 @@ def temporal_label_smoothing_prob_mask( prob_mask[:, :, padding_index] = 0 # clear cumulative count on prob_mask = prob_mask.float() # convert to float sum_prob = prob_mask.sum(-1, keepdim=True) - sum_prob[sum_prob.squeeze(-1).eq(0.)] = 1. # to deal with the "division by 0" problem + sum_prob[sum_prob.squeeze(-1).eq(0.0)] = 1.0 # to deal with the "division by 0" problem prob_mask = prob_mask.div_(sum_prob).view(-1, prob_mask.size(-1)) return prob_mask @@ -109,8 +109,8 @@ def label_smoothed_nll_loss( raise ValueError("Unsupported smoothing type: {}".format(smoothing_type)) if ignore_index is not None: pad_mask = target.eq(ignore_index) - nll_loss.masked_fill_(pad_mask, 0.) - smooth_loss.masked_fill_(pad_mask, 0.) + nll_loss.masked_fill_(pad_mask, 0.0) + smooth_loss.masked_fill_(pad_mask, 0.0) else: nll_loss = nll_loss.squeeze(-1) smooth_loss = smooth_loss.squeeze(-1) @@ -118,7 +118,7 @@ def label_smoothed_nll_loss( nll_loss = nll_loss.sum() smooth_loss = smooth_loss.sum() eps_i = epsilon / lprobs.size(-1) if smoothing_type == "uniform" else epsilon - loss = (1. - epsilon) * nll_loss + eps_i * smooth_loss + loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss return loss, nll_loss @@ -126,9 +126,15 @@ def label_smoothed_nll_loss( class LabelSmoothedCrossEntropyV2Criterion(LabelSmoothedCrossEntropyCriterion): def __init__( - self, task, sentence_avg, label_smoothing, smoothing_type, - print_training_sample_interval, unigram_pseudo_count, - ignore_prefix_size=0, report_accuracy=False, + self, + task, + sentence_avg, + label_smoothing, + smoothing_type, + print_training_sample_interval, + unigram_pseudo_count, + ignore_prefix_size=0, + report_accuracy=False, ): super().__init__( task, sentence_avg, label_smoothing, @@ -149,7 +155,7 @@ def __init__( @classmethod def add_args(cls, parser): """Add criterion-specific arguments to the parser.""" - dc = getattr(cls, '__dataclass', None) + dc = getattr(cls, "__dataclass", None) if dc is not None: gen_parser_from_dataclass(parser, dc()) @@ -212,8 +218,13 @@ def compute_loss( padding_index=self.padding_idx, ) if smoothing_type == "temporal" else None loss, nll_loss = label_smoothed_nll_loss( - lprobs, target, self.eps, ignore_index=self.padding_idx, reduce=reduce, - smoothing_type=smoothing_type, prob_mask=prob_mask, + lprobs, + target, + self.eps, + ignore_index=self.padding_idx, + reduce=reduce, + smoothing_type=smoothing_type, + prob_mask=prob_mask, unigram_tensor=self.unigram_tensor, ) return loss, nll_loss, lprobs diff --git a/espresso/criterions/lf_mmi_loss.py b/espresso/criterions/lf_mmi_loss.py index 997b1a422..dd00e9188 100644 --- a/espresso/criterions/lf_mmi_loss.py +++ b/espresso/criterions/lf_mmi_loss.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. 
from dataclasses import dataclass, field -from omegaconf import II import logging import math @@ -14,6 +13,7 @@ from fairseq.criterions import FairseqCriterion, register_criterion from fairseq.dataclass import FairseqDataclass from fairseq.logging import metrics +from omegaconf import II logger = logging.getLogger(__name__) @@ -21,9 +21,9 @@ @dataclass class LatticeFreeMMICriterionConfig(FairseqDataclass): - sentence_avg: bool = II("params.optimization.sentence_avg") + sentence_avg: bool = II("optimization.sentence_avg") denominator_fst_path: str = field( - default=None, metadata={"help": "path to the denominator fst file"} + default="???", metadata={"help": "path to the denominator fst file"} ) leaky_hmm_coefficient: float = field( default=1.0e-05, @@ -215,10 +215,11 @@ def compute_loss(self, net_output, sample, reduce=True): def reduce_metrics(cls, logging_outputs) -> None: """Aggregate logging outputs from data parallel training.""" loss_sum = sum(log.get("loss", 0) for log in logging_outputs) - nll_loss_sum = sum(log.get('nll_loss', 0) for log in logging_outputs) + nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs) ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + # we divide by log(2) to convert the loss from base e to base 2 metrics.log_scalar( "loss", loss_sum / sample_size / math.log(2), sample_size, round=7 ) diff --git a/espresso/data/asr_chain_dataset.py b/espresso/data/asr_chain_dataset.py index 2ed83741f..5b88580fc 100644 --- a/espresso/data/asr_chain_dataset.py +++ b/espresso/data/asr_chain_dataset.py @@ -12,7 +12,7 @@ import torch -from fairseq.data import data_utils, FairseqDataset +from fairseq.data import FairseqDataset, data_utils import espresso.tools.utils as speech_utils @@ -48,12 +48,15 @@ def merge(key, pad_to_length=None): raise ValueError("Invalid key.") id = torch.LongTensor([s["id"] for s in samples]) - src_frames = merge("source", pad_to_length=pad_to_length["source"] if pad_to_length is not None else None) + src_frames = merge( + "source", + pad_to_length=pad_to_length["source"] if pad_to_length is not None else None, + ) # sort by descending source length if pad_to_length is not None or src_bucketed: - src_lengths = torch.IntTensor([ - s["source"].ne(0.0).any(dim=1).int().sum() for s in samples - ]) + src_lengths = torch.IntTensor( + [s["source"].ne(0.0).any(dim=1).int().sum() for s in samples] + ) else: src_lengths = torch.IntTensor([s["source"].size(0) for s in samples]) src_lengths, sort_order = src_lengths.sort(descending=True) @@ -134,8 +137,7 @@ def filter_and_reorder(self, indices): assert isinstance(indices, (list, np.ndarray)) indices = np.array(indices) assert all(indices < len(self.utt_ids)) and all(indices >= 0) - assert len(np.unique(indices)) == len(indices), \ - "Duplicate elements in indices." + assert len(np.unique(indices)) == len(indices), "Duplicate elements in indices." 
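Both criteria above are being moved onto hydra-style dataclass configs: II("optimization.sentence_avg") makes a field interpolate its value from the top-level config, and "???" (omegaconf's MISSING) marks a field that must be supplied before use. A minimal omegaconf sketch with made-up names, assuming omegaconf 2.x:

from dataclasses import dataclass
from omegaconf import II, MISSING, OmegaConf

@dataclass
class ExampleCriterionConfig:
    # resolved from the top-level config when the value is read
    sentence_avg: bool = II("optimization.sentence_avg")
    # MISSING is the same as default="???": reading it before it is set raises
    denominator_fst_path: str = MISSING

root = OmegaConf.create({"optimization": {"sentence_avg": True}})
root.criterion = OmegaConf.structured(ExampleCriterionConfig)
print(root.criterion.sentence_avg)  # True, pulled from optimization.sentence_avg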
self.utt_ids = [self.utt_ids[i] for i in indices] self.rxfiles = [self.rxfiles[i] for i in indices] self.numerator_graphs = [self.numerator_graphs[i] for i in indices] @@ -172,8 +174,15 @@ class AsrChainDataset(FairseqDataset): """ def __init__( - self, src, src_sizes, tgt=None, tgt_sizes=None, text=None, shuffle=True, - num_buckets=0, pad_to_multiple=1, + self, + src, + src_sizes, + tgt=None, + tgt_sizes=None, + text=None, + shuffle=True, + num_buckets=0, + pad_to_multiple=1, ): self.src = src self.tgt = tgt @@ -196,10 +205,15 @@ def __init__( "Removed {} examples due to empty numerator graphs or missing entries, " "{} remaining".format(num_removed, num_after_matching) ) - self.sizes = np.vstack((self.src_sizes, self.tgt_sizes)).T if self.tgt_sizes is not None else self.src_sizes + self.sizes = ( + np.vstack((self.src_sizes, self.tgt_sizes)).T + if self.tgt_sizes is not None + else self.src_sizes + ) if num_buckets > 0: from espresso.data import FeatBucketPadLengthDataset + self.src = FeatBucketPadLengthDataset( self.src, sizes=self.src_sizes, @@ -215,8 +229,7 @@ def __init__( num_tokens = np.vectorize(self.num_tokens, otypes=[np.long]) self.bucketed_num_tokens = num_tokens(np.arange(len(self.src))) self.buckets = [ - (None, num_tokens) - for num_tokens in np.unique(self.bucketed_num_tokens) + (None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens) ] else: self.buckets = None @@ -293,7 +306,7 @@ def collater(self, samples, pad_to_length=None): Args: samples (List[dict]): samples to collate pad_to_length (dict, optional): a dictionary of - {'source': source_pad_to_length} + {"source": source_pad_to_length} to indicate the max length to pad to in source and target respectively. Returns: @@ -327,7 +340,10 @@ def num_tokens(self, index): def size(self, index): """Return an example's size as a float or tuple. This value is used when filtering a dataset with ``--max-positions``.""" - return (self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0) + return ( + self.src_sizes[index], + self.tgt_sizes[index] if self.tgt_sizes is not None else 0, + ) def ordered_indices(self): """Return an ordered list of indices. Batches will be constructed based @@ -339,9 +355,7 @@ def ordered_indices(self): if self.buckets is None: # sort by target length, then source length if self.tgt_sizes is not None: - indices = indices[ - np.argsort(self.tgt_sizes[indices], kind="mergesort") - ] + indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")] return indices[np.argsort(self.src_sizes[indices], kind="mergesort")] else: # sort by bucketed_num_tokens, which is padded_src_len @@ -358,7 +372,7 @@ def prefetch(self, indices): self.src.prefetch(indices) def filter_indices_by_size(self, indices, max_sizes): - """ Filter a list of sample indices. Remove those that are longer + """Filter a list of sample indices. Remove those that are longer than specified in max_sizes. 
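ordered_indices() above relies on two stable argsorts: sorting by target length first and then (stably) by source length gives an ordering by source length with ties broken by target length. A small numpy illustration with made-up sizes:

import numpy as np

src_sizes = np.array([7, 3, 7, 3])
tgt_sizes = np.array([2, 9, 1, 4])

indices = np.arange(len(src_sizes))
indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]
indices = indices[np.argsort(src_sizes[indices], kind="mergesort")]
print(indices)  # [3 1 2 0]: source lengths 3,3,7,7 with ties ordered by target length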
Args: diff --git a/espresso/data/asr_dataset.py b/espresso/data/asr_dataset.py index b0dc9259a..0246ccb88 100644 --- a/espresso/data/asr_dataset.py +++ b/espresso/data/asr_dataset.py @@ -7,7 +7,7 @@ import numpy as np import torch -from fairseq.data import data_utils, FairseqDataset +from fairseq.data import FairseqDataset, data_utils import espresso.tools.utils as speech_utils @@ -48,14 +48,15 @@ def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None): id = torch.LongTensor([s["id"] for s in samples]) src_frames = merge( - "source", left_pad=left_pad_source, + "source", + left_pad=left_pad_source, pad_to_length=pad_to_length["source"] if pad_to_length is not None else None, ) # sort by descending source length if pad_to_length is not None or src_bucketed: - src_lengths = torch.IntTensor([ - s["source"].ne(0.0).any(dim=1).int().sum() for s in samples - ]) + src_lengths = torch.IntTensor( + [s["source"].ne(0.0).any(dim=1).int().sum() for s in samples] + ) else: src_lengths = torch.IntTensor([s["source"].size(0) for s in samples]) src_lengths, sort_order = src_lengths.sort(descending=True) @@ -68,7 +69,9 @@ def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None): if samples[0].get("target", None) is not None: target = merge( "target", left_pad=left_pad_target, - pad_to_length=pad_to_length["target"] if pad_to_length is not None else None, + pad_to_length=pad_to_length["target"] + if pad_to_length is not None + else None, ) target = target.index_select(0, sort_order) ntokens = sum(s["target"].ne(pad_idx).int().sum().item() for s in samples) @@ -82,7 +85,9 @@ def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None): "target", left_pad=left_pad_target, move_eos_to_beginning=True, - pad_to_length=pad_to_length["target"] if pad_to_length is not None else None, + pad_to_length=pad_to_length["target"] + if pad_to_length is not None + else None, ) else: ntokens = src_lengths.sum().item() @@ -104,7 +109,9 @@ def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None): "target_raw_text": target_raw_text, } if prev_output_tokens is not None: - batch["net_input"]["prev_output_tokens"] = prev_output_tokens.index_select(0, sort_order) + batch["net_input"]["prev_output_tokens"] = prev_output_tokens.index_select( + 0, sort_order + ) if samples[0].get("constraints", None) is not None: # Collate the packed constraints across the samples, padding to @@ -112,7 +119,7 @@ def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None): lens = [sample.get("constraints").size(0) for sample in samples] constraints = torch.zeros((len(samples), max(lens))).long() for i, sample in enumerate(samples): - constraints[i, 0:lens[i]] = samples[i].get("constraints") + constraints[i, 0: lens[i]] = samples[i].get("constraints") batch["constraints"] = constraints return batch @@ -141,19 +148,25 @@ class AsrDataset(FairseqDataset): num_buckets (int, optional): if set to a value greater than 0, then batches will be bucketed into the given number of batch shapes. src_lang_id (int, optional): source language ID, if set, the collated batch - will contain a field 'src_lang_id' in 'net_input' which indicates the + will contain a field "src_lang_id" in "net_input" which indicates the source language of the samples. tgt_lang_id (int, optional): target language ID, if set, the collated batch - will contain a field 'tgt_lang_id' which indicates the target language + will contain a field "tgt_lang_id" which indicates the target language of the samples. 
pad_to_multiple (int, optional): pad src/tgt lengths to a multiple of this value """ def __init__( - self, src, src_sizes, - tgt=None, tgt_sizes=None, dictionary=None, - left_pad_source=False, left_pad_target=False, - shuffle=True, input_feeding=True, + self, + src, + src_sizes, + tgt=None, + tgt_sizes=None, + dictionary=None, + left_pad_source=False, + left_pad_target=False, + shuffle=True, + input_feeding=True, constraints=None, num_buckets=0, src_lang_id=None, @@ -175,10 +188,15 @@ def __init__( self.tgt_lang_id = tgt_lang_id if self.tgt is not None: self._match_src_tgt() - self.sizes = np.vstack((self.src_sizes, self.tgt_sizes)).T if self.tgt_sizes is not None else self.src_sizes + self.sizes = ( + np.vstack((self.src_sizes, self.tgt_sizes)).T + if self.tgt_sizes is not None + else self.src_sizes + ) if num_buckets > 0: from espresso.data import FeatBucketPadLengthDataset, TextBucketPadLengthDataset + self.src = FeatBucketPadLengthDataset( self.src, sizes=self.src_sizes, @@ -204,8 +222,7 @@ def __init__( num_tokens = np.vectorize(self.num_tokens, otypes=[np.long]) self.bucketed_num_tokens = num_tokens(np.arange(len(self.src))) self.buckets = [ - (None, num_tokens) - for num_tokens in np.unique(self.bucketed_num_tokens) + (None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens) ] else: self.buckets = None @@ -261,7 +278,7 @@ def collater(self, samples, pad_to_length=None): Args: samples (List[dict]): samples to collate pad_to_length (dict, optional): a dictionary of - {'source': source_pad_to_length, 'target': target_pad_to_length} + {"source": source_pad_to_length, "target": target_pad_to_length} to indicate the max length to pad to in source and target respectively. Returns: @@ -309,13 +326,13 @@ def collater(self, samples, pad_to_length=None): src_tokens = res["net_input"]["src_tokens"] bsz = src_tokens.size(0) if self.src_lang_id is not None: - res["net_input"]["src_lang_id"] = torch.LongTensor( - [[self.src_lang_id]] - ).expand(bsz, 1).to(src_tokens) + res["net_input"]["src_lang_id"] = ( + torch.LongTensor([[self.src_lang_id]]).expand(bsz, 1).to(src_tokens) + ) if self.tgt_lang_id is not None: - res["tgt_lang_id"] = torch.LongTensor( - [[self.tgt_lang_id]] - ).expand(bsz, 1).to(src_tokens) + res["tgt_lang_id"] = ( + torch.LongTensor([[self.tgt_lang_id]]).expand(bsz, 1).to(src_tokens) + ) return res def num_tokens(self, index): @@ -326,7 +343,10 @@ def num_tokens(self, index): def size(self, index): """Return an example's size as a float or tuple. This value is used when filtering a dataset with ``--max-positions``.""" - return (self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0) + return ( + self.src_sizes[index], + self.tgt_sizes[index] if self.tgt_sizes is not None else 0, + ) def ordered_indices(self): """Return an ordered list of indices. Batches will be constructed based @@ -338,9 +358,7 @@ def ordered_indices(self): if self.buckets is None: # sort by target length, then source length if self.tgt_sizes is not None: - indices = indices[ - np.argsort(self.tgt_sizes[indices], kind="mergesort") - ] + indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")] return indices[np.argsort(self.src_sizes[indices], kind="mergesort")] else: # sort by bucketed_num_tokens, which is padded_src_len @@ -357,7 +375,7 @@ def prefetch(self, indices): self.src.prefetch(indices) def filter_indices_by_size(self, indices, max_sizes): - """ Filter a list of sample indices. Remove those that are longer + """Filter a list of sample indices. 
Remove those that are longer than specified in max_sizes. Args: diff --git a/espresso/data/asr_dictionary.py b/espresso/data/asr_dictionary.py index 3c0a07912..4cdd44dac 100644 --- a/espresso/data/asr_dictionary.py +++ b/espresso/data/asr_dictionary.py @@ -3,10 +3,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import torch +from argparse import Namespace +from typing import Union +import torch from fairseq.data import Dictionary, encoders from fairseq.file_io import PathManager +from omegaconf import DictConfig # will automatically load modules defined from there from espresso.data import encoders as encoders_espresso @@ -24,8 +27,9 @@ def __init__( space="", extra_special_symbols=None, ): - self.bos_word, self.unk_word, self.pad_word, self.eos_word, self.space_word = \ + self.bos_word, self.unk_word, self.pad_word, self.eos_word, self.space_word = ( bos, unk, pad, eos, space + ) self.symbols = [] self.count = [] self.indices = {} @@ -78,12 +82,13 @@ def load(cls, f, f_non_lang_syms=None): except UnicodeError: raise Exception( "Incorrect encoding detected in {}, please " - "rebuild the dataset".format(f) + "rebuild the dataset".format(fd) ) for sym in non_lang_syms: - assert d.index(sym) != d.unk(), \ - "{} in {} is not in the dictionary".format(sym, f_non_lang_syms) + assert ( + d.index(sym) != d.unk() + ), "{} in {} is not in the dictionary".format(sym, f_non_lang_syms) d.non_lang_syms = non_lang_syms return d @@ -94,17 +99,19 @@ def dummy_sentence(self, length): t[-1] = self.eos() return t - def build_tokenizer(self, args): - self.tokenizer = encoders.build_tokenizer(args) + def build_tokenizer(self, cfg: Union[DictConfig, Namespace]): + self.tokenizer = encoders.build_tokenizer(cfg) - def build_bpe(self, args): - if args.bpe == "characters_asr": + def build_bpe(self, cfg: Union[DictConfig, Namespace]): + if ( + (isinstance(cfg, DictConfig) and cfg._name == "characters_asr") + or (isinstance(cfg, Namespace) and getattr(cfg, "bpe", None) == "characters_asr") + ): self.bpe = encoders.build_bpe( - args, space_symbol=self.space_word, ends_with_space=True, - non_lang_syms=self.non_lang_syms, + cfg, space_symbol=self.space_word, non_lang_syms=self.non_lang_syms ) else: - self.bpe = encoders.build_bpe(args) + self.bpe = encoders.build_bpe(cfg) def wordpiece_encode(self, x): if self.tokenizer is not None: diff --git a/espresso/data/asr_xent_dataset.py b/espresso/data/asr_xent_dataset.py index e831345f9..47ed464f3 100644 --- a/espresso/data/asr_xent_dataset.py +++ b/espresso/data/asr_xent_dataset.py @@ -12,7 +12,7 @@ import torch import torch.nn.functional as F -from fairseq.data import data_utils, FairseqDataset +from fairseq.data import FairseqDataset, data_utils import espresso.tools.utils as speech_utils @@ -111,18 +111,26 @@ def chunking(src_item, tgt_item, tgt_start): s["source"] = src_item[: label_delay] if pad_to_length is not None or src_bucketed: - src_lengths = torch.IntTensor([ - s["source"].ne(0.0).any(dim=1).int().sum() for s in samples - ]) + src_lengths = torch.IntTensor( + [s["source"].ne(0.0).any(dim=1).int().sum() for s in samples] + ) else: src_lengths = torch.IntTensor([s["source"].size(0) for s in samples]) id = torch.LongTensor([s["id"] for s in samples]) utt_id = [s["utt_id"] for s in samples] - src_frames = merge("source", pad_to_length=pad_to_length["source"] if pad_to_length is not None else None) + src_frames = merge( + "source", + pad_to_length=pad_to_length["source"] if pad_to_length 
is not None else None, + ) target = None if samples[0].get("target", None) is not None: - target = merge("target", pad_to_length=pad_to_length["target"] if pad_to_length is not None else None) + target = merge( + "target", + pad_to_length=pad_to_length["target"] + if pad_to_length is not None + else None, + ) ntokens = sum(s["target"].ne(pad_idx).int().sum().item() for s in samples) else: ntokens = src_lengths.sum().item() @@ -181,14 +189,25 @@ def chunking(src_item, tgt_item, tgt_start): s["source"] = ori_source[i].new_zeros( chunk_width + chunk_left_context + chunk_right_context, ori_source[i].size(1) ) - s["target"] = ori_target[i].new_full((chunk_width,), pad_idx) \ - if ori_target[i] is not None else None - src_frames = merge("source", pad_to_length=pad_to_length["source"] if pad_to_length is not None else None) + s["target"] = ( + ori_target[i].new_full((chunk_width,), pad_idx) + if ori_target[i] is not None + else None + ) + src_frames = merge( + "source", + pad_to_length=pad_to_length["source"] if pad_to_length is not None else None, + ) src_chunk_lengths = torch.IntTensor([s["source"].size(0) for s in samples]) target = None if samples[0].get("target", None) is not None: - target = merge("target", pad_to_length=pad_to_length["target"] if pad_to_length is not None else None) + target = merge( + "target", + pad_to_length=pad_to_length["target"] + if pad_to_length is not None + else None, + ) ntokens = sum(s["target"].ne(pad_idx).int().sum().item() for s in samples) else: ntokens = src_lengths.sum().item() @@ -218,8 +237,12 @@ class AliScpCachedDataset(torch.utils.data.Dataset): """ def __init__( - self, utt_ids: List[str], rxfiles: List[str], utt2num_frames: Optional[List[int]] = None, - ordered_prefetch=False, cache_size=327680, + self, + utt_ids: List[str], + rxfiles: List[str], + utt2num_frames: Optional[List[int]] = None, + ordered_prefetch=False, + cache_size=327680, ): super().__init__() assert len(utt_ids) == len(rxfiles) @@ -277,8 +300,7 @@ def filter_and_reorder(self, indices): assert isinstance(indices, (list, np.ndarray)) indices = np.array(indices) assert all(indices < len(self.utt_ids)) and all(indices >= 0) - assert len(np.unique(indices)) == len(indices), \ - "Duplicate elements in indices." + assert len(np.unique(indices)) == len(indices), "Duplicate elements in indices." self.utt_ids = [self.utt_ids[i] for i in indices] self.rxfiles = [self.rxfiles[i] for i in indices] self.sizes = self.sizes[indices] @@ -288,9 +310,9 @@ def filter_and_reorder(self, indices): def __getitem__(self, i): self.check_index(i) if i not in self.cache_index: - assert self.start_pos_for_next_cache < \ - len(self.ordered_indices), \ - "Position for next cache starting beyond the end of ordered_indices." + assert ( + self.start_pos_for_next_cache < len(self.ordered_indices) + ), "Position for next cache starting beyond the end of ordered_indices." 
try: pos_start = self.ordered_indices.index( i, self.start_pos_for_next_cache, @@ -304,8 +326,7 @@ def __getitem__(self, i): pos_end = min( pos_start + self.cache_size, len(self.ordered_indices), ) - self.start_pos_for_next_cache = pos_end \ - if self.ordered_prefetch else 0 + self.start_pos_for_next_cache = pos_end if self.ordered_prefetch else 0 total_size = 0 for idx in self.ordered_indices[pos_start: pos_end]: total_size += self.sizes[idx] @@ -358,9 +379,21 @@ class AsrXentDataset(FairseqDataset): """ def __init__( - self, src, src_sizes, tgt: Optional[AliScpCachedDataset] = None, tgt_sizes=None, text=None, - shuffle=True, num_buckets=0, pad_to_multiple=1, seed=1, chunk_width=None, - chunk_left_context=None, chunk_right_context=None, label_delay=0, random_chunking=True, + self, + src, + src_sizes, + tgt: Optional[AliScpCachedDataset] = None, + tgt_sizes=None, + text=None, + shuffle=True, + num_buckets=0, + pad_to_multiple=1, + seed=1, + chunk_width=None, + chunk_left_context=None, + chunk_right_context=None, + label_delay=0, + random_chunking=True, ): self.src = src self.tgt = tgt @@ -375,8 +408,10 @@ def __init__( assert chunk_left_context >= 0 and chunk_right_context >= 0 self.chunk_left_context = chunk_left_context self.chunk_right_context = chunk_right_context - assert (label_delay < 0 and -label_delay <= chunk_right_context) or \ - (label_delay >= 0 and (chunk_width is None or label_delay < chunk_width)) + assert ( + (label_delay < 0 and -label_delay <= chunk_right_context) + or (label_delay >= 0 and (chunk_width is None or label_delay < chunk_width)) + ) self.label_delay = label_delay self.random_chunking = random_chunking if self.tgt is not None: @@ -385,7 +420,11 @@ def __init__( changed = self._match_src_text() if self.tgt is not None and changed: self._match_src_tgt() - self.sizes = np.vstack((self.src_sizes, self.tgt_sizes)).T if self.tgt_sizes is not None else self.src_sizes + self.sizes = ( + np.vstack((self.src_sizes, self.tgt_sizes)).T + if self.tgt_sizes is not None + else self.src_sizes + ) if chunk_width is not None: # remove those whose lengths are shorter than chunk_size @@ -406,6 +445,7 @@ def __init__( if num_buckets > 0: from fairseq.data import BucketPadLengthDataset from espresso.data import FeatBucketPadLengthDataset + self.src = FeatBucketPadLengthDataset( self.src, sizes=self.src_sizes, @@ -431,8 +471,7 @@ def __init__( num_tokens = np.vectorize(self.num_tokens, otypes=[np.long]) self.bucketed_num_tokens = num_tokens(np.arange(len(self.src))) self.buckets = [ - (None, num_tokens) - for num_tokens in np.unique(self.bucketed_num_tokens) + (None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens) ] else: self.buckets = None @@ -511,7 +550,7 @@ def collater(self, samples, pad_to_length=None): Args: samples (List[dict]): samples to collate pad_to_length (dict, optional): a dictionary of - {'source': source_pad_to_length, 'target': target_pad_to_length} + {"source": source_pad_to_length, "target": target_pad_to_length} to indicate the max length to pad to in source and target respectively. 
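The chunking options above (chunk_width, chunk_left_context, chunk_right_context) describe how a long utterance is cut into fixed-width label chunks with extra acoustic context on both sides for frame-level cross-entropy training. A simplified sketch of the idea with made-up values (illustrative only; the collater in this file additionally handles label_delay, boundary padding, and random versus uniform chunk starts):

import torch

def get_chunk(feats, labels, start, chunk_width, left_context, right_context):
    """Cut a fixed-width label chunk plus extra acoustic context around it."""
    lo = max(start - left_context, 0)
    hi = min(start + chunk_width + right_context, feats.size(0))
    return feats[lo:hi], labels[start: start + chunk_width]

feats = torch.randn(100, 40)             # 100 frames of 40-dim features
labels = torch.randint(0, 500, (100,))   # one frame-level label per frame
src_chunk, tgt_chunk = get_chunk(
    feats, labels, start=30, chunk_width=20, left_context=5, right_context=5
)
print(src_chunk.shape, tgt_chunk.shape)  # torch.Size([30, 40]) torch.Size([20])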
@@ -571,9 +610,7 @@ def ordered_indices(self): if self.buckets is None: # sort by target length, then source length if self.tgt_sizes is not None: - indices = indices[ - np.argsort(self.tgt_sizes[indices], kind="mergesort") - ] + indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")] return indices[np.argsort(self.src_sizes[indices], kind="mergesort")] else: # sort by bucketed_num_tokens, which is padded_src_len @@ -591,7 +628,7 @@ def prefetch(self, indices): self.tgt.prefetch(indices) def filter_indices_by_size(self, indices, max_sizes): - """ Filter a list of sample indices. Remove those that are longer + """Filter a list of sample indices. Remove those that are longer than specified in max_sizes. Args: diff --git a/espresso/data/encoders/characters_asr.py b/espresso/data/encoders/characters_asr.py index ef424150f..0bd9a48d0 100644 --- a/espresso/data/encoders/characters_asr.py +++ b/espresso/data/encoders/characters_asr.py @@ -3,23 +3,24 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - +from dataclasses import dataclass from typing import List, Optional from fairseq.data.encoders import register_bpe +from fairseq.dataclass import FairseqDataclass from espresso.tools.utils import tokenize -@register_bpe('characters_asr') -class CharactersAsr(object): +@dataclass +class CharactersAsrConfig(FairseqDataclass): + pass - @staticmethod - def add_args(parser): - pass +@register_bpe("characters_asr", dataclass=CharactersAsrConfig) +class CharactersAsr(object): def __init__( - self, args, space_symbol="", ends_with_space=True, + self, cfg, space_symbol="", ends_with_space=True, non_lang_syms: Optional[List[str]] = None, ): self.space_symbol = space_symbol diff --git a/espresso/data/feat_text_dataset.py b/espresso/data/feat_text_dataset.py index 0dce55976..64a43eba8 100644 --- a/espresso/data/feat_text_dataset.py +++ b/espresso/data/feat_text_dataset.py @@ -30,8 +30,12 @@ class FeatScpDataset(torch.utils.data.Dataset): """ def __init__( - self, utt_ids: List[str], rxfiles: List[str], utt2num_frames: Optional[List[int]] = None, - seed=1, specaugment_config: Optional[str] = None, + self, + utt_ids: List[str], + rxfiles: List[str], + utt2num_frames: Optional[List[int]] = None, + seed=1, + specaugment_config: Optional[str] = None, ): super().__init__() assert len(utt_ids) == len(rxfiles) diff --git a/espresso/dump_posteriors.py b/espresso/dump_posteriors.py index ee0348426..214172fe4 100755 --- a/espresso/dump_posteriors.py +++ b/espresso/dump_posteriors.py @@ -12,14 +12,17 @@ import logging import os import sys +from argparse import Namespace import numpy as np import torch from fairseq import checkpoint_utils, options, tasks, utils +from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import progress_bar from fairseq.logging.meters import StopwatchMeter +from omegaconf import DictConfig try: import kaldi_io @@ -27,12 +30,16 @@ raise ImportError("Please install kaldi_io with: pip install kaldi_io") -def main(args): - assert args.path is not None, "--path required for decoding!" - return _main(args, sys.stderr) +def main(cfg: DictConfig): + if isinstance(cfg, Namespace): + cfg = convert_namespace_to_omegaconf(cfg) -def _main(args, output_file): + assert cfg.common_eval.path is not None, "--path required for decoding!" 
+ return _main(cfg, sys.stderr) + + +def _main(cfg, output_file): logging.basicConfig( format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S", @@ -41,41 +48,41 @@ def _main(args, output_file): ) logger = logging.getLogger("espresso.dump_posteriors") - print_options_meaning_changes(args, logger) + print_options_meaning_changes(cfg, logger) - utils.import_user_module(args) + utils.import_user_module(cfg.common) - if args.max_tokens is None and args.batch_size is None: - args.max_tokens = 12000 - logger.info(args) + if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None: + cfg.dataset.max_tokens = 12000 + logger.info(cfg) # Fix seed for stochastic decoding - if args.seed is not None and not args.no_seed_provided: - np.random.seed(args.seed) - utils.set_torch_seed(args.seed) + if cfg.common.seed is not None and not cfg.generation.no_seed_provided: + np.random.seed(cfg.common.seed) + utils.set_torch_seed(cfg.common.seed) - use_cuda = torch.cuda.is_available() and not args.cpu + use_cuda = torch.cuda.is_available() and not cfg.common.cpu # Load dataset split - task = tasks.setup_task(args) - task.load_dataset(args.gen_subset) + task = tasks.setup_task(cfg.task) + task.load_dataset(cfg.dataset.gen_subset) - overrides = ast.literal_eval(args.model_overrides) + overrides = ast.literal_eval(cfg.common_eval.model_overrides) # Load ensemble - logger.info("loading model(s) from {}".format(args.path)) + logger.info("loading model(s) from {}".format(cfg.common_eval.path)) models, _model_args = checkpoint_utils.load_model_ensemble( - utils.split_paths(args.path), + utils.split_paths(cfg.common_eval.path), arg_overrides=overrides, task=task, - suffix=getattr(args, "checkpoint_suffix", ""), - strict=(args.checkpoint_shard_count == 1), - num_shards=args.checkpoint_shard_count, + suffix=cfg.checkpoint.checkpoint_suffix, + strict=(cfg.checkpoint.checkpoint_shard_count == 1), + num_shards=cfg.checkpoint.checkpoint_shard_count, ) # Load state prior for cross-entropy trained systems decoding - if args.state_prior_file is not None: - prior = torch.from_numpy(kaldi_io.read_vec_flt(args.state_prior_file)) + if cfg.generation.state_prior_file is not None: + prior = torch.from_numpy(kaldi_io.read_vec_flt(cfg.generation.state_prior_file)) else: prior = [] @@ -83,11 +90,11 @@ def _main(args, output_file): for model in models: if model is None: continue - if args.fp16: + if cfg.common.fp16: model.half() - if use_cuda and not args.pipeline_model_parallel: + if use_cuda and not cfg.distributed_training.pipeline_model_parallel: model.cuda() - model.prepare_for_inference_(args) + model.prepare_for_inference_(cfg) if isinstance(prior, list) and getattr(model, "state_prior", None) is not None: prior.append(model.state_prior.unsqueeze(0)) @@ -98,7 +105,7 @@ def _main(args, output_file): prior = None if prior is not None: - if args.fp16: + if cfg.common.fp16: prior = prior.half() if use_cuda: prior = prior.cuda() @@ -108,31 +115,30 @@ def _main(args, output_file): # Load dataset (possibly sharded) itr = task.get_batch_iterator( - dataset=task.dataset(args.gen_subset), - max_tokens=args.max_tokens, - max_sentences=args.batch_size, + dataset=task.dataset(cfg.dataset.gen_subset), + max_tokens=cfg.dataset.max_tokens, + max_sentences=cfg.dataset.batch_size, max_positions=utils.resolve_max_positions( - task.max_positions(), - *[model.max_positions() if hasattr(model, "encoder") - else (None, model.max_positions()) for model in models] + task.max_positions(), *[m.max_positions() for m 
in models] ), - ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, - required_batch_size_multiple=args.required_batch_size_multiple, - num_shards=args.num_shards, - shard_id=args.shard_id, - num_workers=args.num_workers, - data_buffer_size=args.data_buffer_size, + ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=cfg.dataset.required_batch_size_multiple, + seed=cfg.common.seed, + num_shards=cfg.distributed_training.distributed_world_size, + shard_id=cfg.distributed_training.distributed_rank, + num_workers=cfg.dataset.num_workers, + data_buffer_size=cfg.dataset.data_buffer_size, ).next_epoch_itr(shuffle=False) progress = progress_bar.progress_bar( itr, - log_format=args.log_format, - log_interval=args.log_interval, - default_log_format=("tqdm" if not args.no_progress_bar else "none"), + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), ) # Initialize generator gen_timer = StopwatchMeter() - generator = task.build_generator(models, args) + generator = task.build_generator(models, cfg.generation) # Generate and dump num_sentences = 0 @@ -153,12 +159,12 @@ def _main(args, output_file): out_lengths = (~padding_mask).long().sum(dim=1).cpu() if padding_mask is not None else None num_processed_frames = sample["ntokens"] gen_timer.stop(num_processed_frames) - num_sentences += sample["nsentences"] if "nsentences" in sample else sample['id'].numel() + num_sentences += sample["nsentences"] if "nsentences" in sample else sample["id"].numel() if out_lengths is not None: for i in range(sample["nsentences"]): length = out_lengths[i] - kaldi_io.write_mat(f, lprobs[i, :length, :].cpu().numpy(), key=sample["utt_id"][i]) + kaldi_io.write_mat(f, lprobs[i, : length, :].cpu().numpy(), key=sample["utt_id"][i]) else: for i in range(sample["nsentences"]): kaldi_io.write_mat(f, lprobs[i, :, :].cpu().numpy(), key=sample["utt_id"][i]) @@ -189,9 +195,9 @@ def _main(args, output_file): num_sentences += len(utt_id) for j in range(len(utt_id)): truncated_length = models[0].output_lengths( - task.dataset(args.gen_subset).src_sizes[id[j]] + task.dataset(cfg.dataset.gen_subset).src_sizes[id[j]] ) # length is after possible subsampling by the model - mat = whole_lprobs[j, :truncated_length, :] + mat = whole_lprobs[j, : truncated_length, :] kaldi_io.write_mat(f, mat.numpy(), key=utt_id[j]) logger.info("Dumped {} utterances ({} frames) in {:.1f}s ({:.2f} sentences/s, {:.2f} frames/s)".format( @@ -200,7 +206,7 @@ def _main(args, output_file): return -def print_options_meaning_changes(args, logger): +def print_options_meaning_changes(cfg, logger): """Options that have different meanings than those in the translation task are explained here. """ @@ -209,12 +215,6 @@ def print_options_meaning_changes(args, logger): def cli_main(): parser = options.get_generation_parser(default_task="speech_recognition_hybrid") - parser.add_argument("--apply-log-softmax", action="store_true", - help="Apply log-softmax to the neural network outputs for some " - "systems, e.g., Xent. Otherwise use the raw outputs") - parser.add_argument("--state-prior-file", default=None, type=str, metavar="FILE", - help="state prior file. 
If provided, use this file instead of " - "that from the checkpoint") args = options.parse_args_and_arch(parser) main(args) diff --git a/espresso/models/external_language_model.py b/espresso/models/external_language_model.py index dd418ff97..c1f04307f 100644 --- a/espresso/models/external_language_model.py +++ b/espresso/models/external_language_model.py @@ -56,9 +56,10 @@ def __init__(self, wordlm, subword_dict, oov_penalty=1e-4, open_vocab=True): assert isinstance(wordlm, FairseqLanguageModel) self.lm_decoder = wordlm.decoder - assert hasattr(self.lm_decoder, 'masked_copy_incremental_state') and \ - callable(self.lm_decoder.masked_copy_incremental_state), \ - 'The wrapped decoder should implement masked_copy_incremental_state()' + assert ( + hasattr(self.lm_decoder, "masked_copy_incremental_state") + and callable(self.lm_decoder.masked_copy_incremental_state) + ), "The wrapped decoder should implement masked_copy_incremental_state()" self.oov_penalty = oov_penalty self.open_vocab = open_vocab self.zero = 1e-10 # a sufficiently small value to avoid the log(0) issue @@ -76,7 +77,7 @@ def __init__(self, wordlm, subword_dict, oov_penalty=1e-4, open_vocab=True): self.subword_vocab_size = len(subword_dict) def tokenizer(x): - return tokenize(x, non_lang_syms=subword_dict.non_lang_syms).split(' ') + return tokenize(x, non_lang_syms=subword_dict.non_lang_syms).split(" ") self.lexroot = lexical_prefix_tree(word_dict, subword_dict, tokenizer) def max_out_degree(node): @@ -92,18 +93,17 @@ def max_out_degree(node): @torch.no_grad() def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): - assert incremental_state is not None, \ - 'this model is for incremental decoding only' + assert incremental_state is not None, "this model is for incremental decoding only" prev_output_tokens = prev_output_tokens[:, -1:] bsz = prev_output_tokens.size(0) batch_space_mask = prev_output_tokens.squeeze(-1).eq(self.subword_space_idx) - cached_state = self.lm_decoder.get_incremental_state(incremental_state, 'cached_state') + cached_state = self.lm_decoder.get_incremental_state(incremental_state, "cached_state") if cached_state is None: # it is the first time step assert (prev_output_tokens == self.subword_eos_idx).all(), \ - 'expecting the input to the first time step to be ' + "expecting the input to the first time step to be " w = prev_output_tokens.new_full([bsz, 1], self.word_eos_idx) lm_probs = self.lm_decoder.get_normalized_probs( self.lm_decoder(w, incremental_state=incremental_state), @@ -112,8 +112,8 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): cumsum_probs = torch.cumsum(lm_probs, dim=-1) # B x 1 x V nodes = [self.lexroot] * bsz else: - cumsum_probs = self.get_incremental_state(incremental_state, 'cumsum_probs') - nodes = self.get_incremental_state(incremental_state, 'nodes') + cumsum_probs = self.get_incremental_state(incremental_state, "cumsum_probs") + nodes = self.get_incremental_state(incremental_state, "nodes") assert len(nodes) == bsz w = prev_output_tokens.new([ node.word_idx if node is not None and node.word_idx >= 0 else @@ -129,8 +129,9 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): self.lm_decoder.masked_copy_incremental_state( incremental_state, old_cached_state, batch_space_mask, ) # restore those not masked - cumsum_probs[batch_space_mask] = \ + cumsum_probs[batch_space_mask] = ( torch.cumsum(lm_probs, dim=-1)[batch_space_mask] + ) tokens_list = prev_output_tokens.squeeze(-1).tolist() for i in 
range(bsz): if tokens_list[i] == self.subword_space_idx: @@ -142,8 +143,8 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): else: # no path in the tree nodes[i] = None - self.set_incremental_state(incremental_state, 'cumsum_probs', cumsum_probs) - self.set_incremental_state(incremental_state, 'nodes', nodes) + self.set_incremental_state(incremental_state, "cumsum_probs", cumsum_probs) + self.set_incremental_state(incremental_state, "nodes", nodes) # initialize out_probs (B x 1 x V) if self.open_vocab: @@ -155,8 +156,10 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): # set the probability of emitting to 0 if prev_output_tokens # is or , and that of emitting to 0 if # prev_output_tokens is not - batch_space_eos_mask = batch_space_mask | \ + batch_space_eos_mask = ( + batch_space_mask | prev_output_tokens.squeeze(-1).eq(self.subword_eos_idx) + ) out_probs[batch_space_eos_mask, :, self.subword_space_idx] = self.zero out_probs[~batch_space_mask, :, self.subword_eos_idx] = self.zero # set transition probability to 1 for those whose node is out of the @@ -164,13 +167,13 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): batch_node_none_mask = batch_space_mask.new( [node is None for node in nodes] ) - out_probs[batch_node_none_mask] = 1. + out_probs[batch_node_none_mask] = 1.0 else: # set out_probs to 0 out_probs = cumsum_probs.new_full([bsz, 1, self.subword_vocab_size], self.zero) # compute parent probabilities for those whose node is not None - sum_probs = cumsum_probs.new_full([bsz, 1], 1.) # default for root node + sum_probs = cumsum_probs.new_full([bsz, 1], 1.0) # default for root node left_ranges, right_ranges, batch_node_not_root_mask = [], [], [] for node in nodes: if node is not None and node.word_set is not None: @@ -243,8 +246,9 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): out_logprobs = out_probs.clamp(min=self.zero).log_() # assign log-probs of emitting word to that of emitting subword - out_logprobs[batch_space_mask, :, self.subword_eos_idx] = \ + out_logprobs[batch_space_mask, :, self.subword_eos_idx] = ( lm_probs.log_()[batch_space_mask, :, self.word_eos_idx] + ) # note that here we return log-probs rather than logits, and the second # element is None, which is usually a tensor of attention weights in @@ -254,16 +258,16 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): def reorder_incremental_state(self, incremental_state, new_order): super().reorder_incremental_state(incremental_state, new_order) - cumsum_probs = self.get_incremental_state(incremental_state, 'cumsum_probs') + cumsum_probs = self.get_incremental_state(incremental_state, "cumsum_probs") if cumsum_probs is not None: new_cumsum_probs = cumsum_probs.index_select(0, new_order) - self.set_incremental_state(incremental_state, 'cumsum_probs', new_cumsum_probs) + self.set_incremental_state(incremental_state, "cumsum_probs", new_cumsum_probs) - nodes = self.get_incremental_state(incremental_state, 'nodes') + nodes = self.get_incremental_state(incremental_state, "nodes") if nodes is not None: new_order_list = new_order.tolist() new_nodes = [nodes[i] for i in new_order_list] - self.set_incremental_state(incremental_state, 'nodes', new_nodes) + self.set_incremental_state(incremental_state, "nodes", new_nodes) def get_normalized_probs(self, net_output, log_probs, sample=None): """Get normalized probabilities (or log probs) from a net's output.""" @@ -301,9 
+305,10 @@ def __init__( assert isinstance(wordlm, FairseqLanguageModel) self.wordlm_decoder = wordlm.decoder - assert hasattr(self.wordlm_decoder, 'masked_copy_incremental_state') and \ - callable(self.wordlm_decoder.masked_copy_incremental_state), \ - 'The wrapped decoder should implement masked_copy_incremental_state()' + assert ( + hasattr(self.wordlm_decoder, "masked_copy_incremental_state") and + callable(self.wordlm_decoder.masked_copy_incremental_state) + ), "The wrapped decoder should implement masked_copy_incremental_state()" assert isinstance(subwordlm, FairseqLanguageModel) self.subwordlm_decoder = subwordlm.decoder self.subwordlm_weight = subwordlm_weight @@ -323,13 +328,12 @@ def __init__( self.subword_vocab_size = len(subword_dict) def tokenizer(x): - return tokenize(x, non_lang_syms=subword_dict.non_lang_syms).split(' ') + return tokenize(x, non_lang_syms=subword_dict.non_lang_syms).split(" ") self.lexroot = lexical_prefix_tree(word_dict, subword_dict, tokenizer) @torch.no_grad() def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): - assert incremental_state is not None, \ - 'this model is for incremental decoding only' + assert incremental_state is not None, "this model is for incremental decoding only" prev_output_tokens = prev_output_tokens[:, -1:] bsz = prev_output_tokens.size(0) @@ -337,16 +341,16 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): batch_not_space_mask = ~batch_space_mask wordlm_cached_state = self.wordlm_decoder.get_incremental_state( - incremental_state, 'cached_state', + incremental_state, "cached_state", ) subwordlm_cached_state = self.subwordlm_decoder.get_incremental_state( - incremental_state, 'cached_state', + incremental_state, "cached_state", ) if wordlm_cached_state is None: # it is the first time step assert subwordlm_cached_state is None assert (prev_output_tokens == self.subword_eos_idx).all(), \ - 'expecting the input to the first time step to be ' + "expecting the input to the first time step to be " w = prev_output_tokens.new_full([bsz, 1], self.word_eos_idx) wordlm_logprobs = self.wordlm_decoder.get_normalized_probs( self.wordlm_decoder(w, incremental_state=incremental_state), @@ -362,10 +366,10 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): subword_cumlogprobs = out_logprobs.new_zeros(sw.size()) nodes = [self.lexroot] * bsz else: - wordlm_logprobs = self.get_incremental_state(incremental_state, 'wordlm_logprobs') - out_logprobs = self.get_incremental_state(incremental_state, 'out_logprobs') - subword_cumlogprobs = self.get_incremental_state(incremental_state, 'subword_cumlogprobs') - nodes = self.get_incremental_state(incremental_state, 'nodes') + wordlm_logprobs = self.get_incremental_state(incremental_state, "wordlm_logprobs") + out_logprobs = self.get_incremental_state(incremental_state, "out_logprobs") + subword_cumlogprobs = self.get_incremental_state(incremental_state, "subword_cumlogprobs") + nodes = self.get_incremental_state(incremental_state, "nodes") assert len(nodes) == bsz w = prev_output_tokens.new([ node.word_idx if node is not None and node.word_idx >= 0 else @@ -403,15 +407,17 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): batch_is_child_mask.append(False) token_idx = prev_output_tokens.new(token_idx).unsqueeze(-1) # b x 1 x 1 if self.open_vocab: - subword_cumlogprobs[batch_space_mask] = 0. 
+ subword_cumlogprobs[batch_space_mask] = 0.0 assert batch_not_space_mask.sum().item() == len(token_idx) - subword_cumlogprobs[batch_not_space_mask] += \ + subword_cumlogprobs[batch_not_space_mask] += ( out_logprobs[batch_not_space_mask].gather(-1, token_idx).squeeze(-1) + ) else: - subword_cumlogprobs[~batch_is_child_mask] = 0. + subword_cumlogprobs[~batch_is_child_mask] = 0.0 assert batch_is_child_mask.sum().item() == len(token_idx) - subword_cumlogprobs[batch_is_child_mask] += \ + subword_cumlogprobs[batch_is_child_mask] += ( out_logprobs[batch_is_child_mask].gather(-1, token_idx).squeeze(-1) + ) out_logprobs = self.subwordlm_decoder.get_normalized_probs( self.subwordlm_decoder(prev_output_tokens, incremental_state=incremental_state), @@ -423,9 +429,9 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): batch_oov_mask = batch_not_space_mask & ~batch_is_child_mask out_logprobs[batch_oov_mask] = self.logzero - self.set_incremental_state(incremental_state, 'wordlm_logprobs', wordlm_logprobs) - self.set_incremental_state(incremental_state, 'subword_cumlogprobs', subword_cumlogprobs) - self.set_incremental_state(incremental_state, 'nodes', nodes) + self.set_incremental_state(incremental_state, "wordlm_logprobs", wordlm_logprobs) + self.set_incremental_state(incremental_state, "subword_cumlogprobs", subword_cumlogprobs) + self.set_incremental_state(incremental_state, "nodes", nodes) # apply word-level probabilies for emitting w = prev_output_tokens.new([ @@ -443,16 +449,18 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): # set the probability of emitting to 0 if prev_output_tokens is # or , and that of emitting to 0 if prev_output_tokens # is not - batch_space_eos_mask = batch_space_mask | \ - prev_output_tokens.squeeze(-1).eq(self.subword_eos_idx) + batch_space_eos_mask = ( + batch_space_mask | prev_output_tokens.squeeze(-1).eq(self.subword_eos_idx) + ) out_logprobs[batch_space_eos_mask, :, self.subword_space_idx] = self.logzero out_logprobs[~batch_space_mask, :, self.subword_eos_idx] = self.logzero # add log-probs of emitting word to that of emitting subword - out_logprobs[batch_space_mask, :, self.subword_eos_idx] += \ + out_logprobs[batch_space_mask, :, self.subword_eos_idx] += ( wordlm_logprobs[batch_space_mask, :, self.word_eos_idx] + ) - self.set_incremental_state(incremental_state, 'out_logprobs', out_logprobs) + self.set_incremental_state(incremental_state, "out_logprobs", out_logprobs) # note that here we return log-probs rather than logits, and the second # element is None, which is usually a tensor of attention weights in @@ -462,17 +470,17 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): def reorder_incremental_state(self, incremental_state, new_order): super().reorder_incremental_state(incremental_state, new_order) - for state_name in ['wordlm_logprobs', 'out_logprobs', 'subword_cumlogprobs']: + for state_name in ["wordlm_logprobs", "out_logprobs", "subword_cumlogprobs"]: state = self.get_incremental_state(incremental_state, state_name) if state is not None: new_state = state.index_select(0, new_order) self.set_incremental_state(incremental_state, state_name, new_state) - nodes = self.get_incremental_state(incremental_state, 'nodes') + nodes = self.get_incremental_state(incremental_state, "nodes") if nodes is not None: new_order_list = new_order.tolist() new_nodes = [nodes[i] for i in new_order_list] - self.set_incremental_state(incremental_state, 'nodes', new_nodes) + 
self.set_incremental_state(incremental_state, "nodes", new_nodes) def get_normalized_probs(self, net_output, log_probs, sample=None): """Get normalized probabilities (or log probs) from a net's output.""" diff --git a/espresso/models/lstm_lm.py b/espresso/models/lstm_lm.py index 11e0764ca..7d0f7e721 100644 --- a/espresso/models/lstm_lm.py +++ b/espresso/models/lstm_lm.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass, field -from omegaconf import II from typing import Optional from fairseq import utils @@ -15,6 +14,7 @@ register_model_architecture, ) from fairseq.models.lstm import Embedding +from omegaconf import II from espresso.models.speech_lstm import SpeechLSTMDecoder from espresso.tasks.speech_recognition import SpeechRecognitionEspressoTask @@ -44,7 +44,7 @@ class LSTMLanguageModelEspressoConfig(FairseqDataclass): decoder_out_embed_dim: int = field( default=650, metadata={"help": "decoder output embedding dimension"} ) - decoder_rnn_residual: lambda x: utils.eval_bool(x) = field( + decoder_rnn_residual: bool = field( default=False, metadata={ "help": "create residual connections for rnn decoder layers " @@ -59,7 +59,7 @@ class LSTMLanguageModelEspressoConfig(FairseqDataclass): "Must be used with adaptive_loss criterion" }, ) - share_embed: lambda x: utils.eval_bool(x) = field( + share_embed: bool = field( default=False, metadata={"help": "share input and output embeddings"} ) is_wordlm: bool = field( @@ -79,18 +79,59 @@ class LSTMLanguageModelEspressoConfig(FairseqDataclass): metadata={"help": "dropout probability for decoder output"} ) # TODO common var add to parent - add_bos_token: bool = II("task.add_bos_token") tokens_per_sample: int = II("task.tokens_per_sample") max_target_positions: Optional[int] = II("task.max_target_positions") - tpu: bool = II("params.common.tpu") + tpu: bool = II("common.tpu") -@register_model("lstm_lm_espresso", dataclass=LSTMLanguageModelEspressoConfig) +@register_model("lstm_lm_espresso") class LSTMLanguageModelEspresso(FairseqLanguageModel): def __init__(self, decoder, args): super().__init__(decoder) self.is_wordlm = args.is_wordlm + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + # fmt: off + parser.add_argument("--dropout", type=float, metavar="D", + help="dropout probability") + parser.add_argument("--decoder-embed-dim", type=int, metavar="N", + help="decoder embedding dimension") + parser.add_argument("--decoder-embed-path", type=str, metavar="STR", + help="path to pre-trained decoder embedding") + parser.add_argument("--decoder-freeze-embed", action="store_true", + help="freeze decoder embeddings") + parser.add_argument("--decoder-hidden-size", type=int, metavar="N", + help="decoder hidden size") + parser.add_argument("--decoder-layers", type=int, metavar="N", + help="number of decoder layers") + parser.add_argument("--decoder-out-embed-dim", type=int, metavar="N", + help="decoder output embedding dimension") + parser.add_argument("--decoder-rnn-residual", + type=lambda x: utils.eval_bool(x), + help="create residual connections for rnn decoder " + "layers (starting from the 2nd layer), i.e., the actual " + "output of such layer is the sum of its input and output") + parser.add_argument("--adaptive-softmax-cutoff", metavar="EXPR", + help="comma separated list of adaptive softmax cutoff points. 
" + "Must be used with adaptive_loss criterion") + parser.add_argument("--share-embed", + type=lambda x: utils.eval_bool(x), + help="share input and output embeddings") + parser.add_argument("--is-wordlm", action="store_true", + help="whether it is word LM or subword LM. Only " + "relevant for ASR decoding with LM, and it determines " + "how the underlying decoder instance gets the dictionary " + "from the task instance when calling cls.build_model()") + + # Granular dropout settings (if not specified these default to --dropout) + parser.add_argument("--decoder-dropout-in", type=float, metavar="D", + help="dropout probability for decoder input embedding") + parser.add_argument("--decoder-dropout-out", type=float, metavar="D", + help="dropout probability for decoder output") + # fmt: on + @classmethod def build_model(cls, args, task): """Build a new model instance.""" @@ -100,7 +141,9 @@ def build_model(cls, args, task): if getattr(args, "max_target_positions", None) is not None: max_target_positions = args.max_target_positions else: - max_target_positions = getattr(args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS) + max_target_positions = getattr( + args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS + ) def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): num_embeddings = len(dictionary) @@ -121,9 +164,7 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): pretrained_decoder_embed = None if args.decoder_embed_path: pretrained_decoder_embed = load_pretrained_embedding_from_file( - args.decoder_embed_path, - dictionary, - args.decoder_embed_dim + args.decoder_embed_path, dictionary, args.decoder_embed_dim ) # one last double check of parameter combinations if args.share_embed and ( @@ -150,7 +191,8 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): share_input_output_embed=args.share_embed, adaptive_softmax_cutoff=( utils.eval_str_list(args.adaptive_softmax_cutoff, type=int) - if args.criterion == "adaptive_loss" else None + if args.criterion == "adaptive_loss" + else None ), max_target_positions=max_target_positions, ) @@ -181,39 +223,33 @@ def lstm_lm_wsj(args): @register_model_architecture("lstm_lm_espresso", "lstm_lm_librispeech") def lstm_lm_librispeech(args): - args.dropout = 0.0 - args.decoder_embed_dim = 800 - args.decoder_hidden_size = 800 - args.decoder_layers = 4 - args.decoder_out_embed_dim = 800 - args.decoder_dropout_in = args.dropout - args.decoder_dropout_out = args.dropout - args.share_embed = True + args.dropout = getattr(args, "dropout", 0.0) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 800) + args.decoder_hidden_size = getattr(args, "decoder_hidden_size", 800) + args.decoder_layers = getattr(args, "decoder_layers", 4) + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 800) + args.share_embed = getattr(args, "share_embed", True) base_lm_architecture(args) @register_model_architecture("lstm_lm_espresso", "lstm_lm_swbd") def lstm_lm_swbd(args): - args.dropout = 0.3 - args.decoder_embed_dim = 1800 - args.decoder_hidden_size = 1800 - args.decoder_layers = 3 - args.decoder_out_embed_dim = 1800 - args.decoder_dropout_in = args.dropout - args.decoder_dropout_out = args.dropout - args.share_embed = True + args.dropout = getattr(args, "dropout", 0.3) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1800) + args.decoder_hidden_size = getattr(args, "decoder_hidden_size", 1800) + args.decoder_layers = getattr(args, "decoder_layers", 3) + 
args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 1800) + args.share_embed = getattr(args, "share_embed", True) base_lm_architecture(args) @register_model_architecture("lstm_lm_espresso", "lstm_wordlm_wsj") def lstm_wordlm_wsj(args): - args.dropout = 0.35 - args.decoder_embed_dim = 1200 - args.decoder_hidden_size = 1200 - args.decoder_layers = 3 - args.decoder_out_embed_dim = 1200 - args.decoder_dropout_in = args.dropout - args.decoder_dropout_out = args.dropout - args.share_embed = True + args.dropout = getattr(args, "dropout", 0.35) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1200) + args.decoder_hidden_size = getattr(args, "decoder_hidden_size", 1200) + args.decoder_layers = getattr(args, "decoder_layers", 3) + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 1200) + args.share_embed = getattr(args, "share_embed", True) args.is_wordlm = True base_lm_architecture(args) diff --git a/espresso/models/speech_lstm.py b/espresso/models/speech_lstm.py index a644b2536..4764ffd89 100644 --- a/espresso/models/speech_lstm.py +++ b/espresso/models/speech_lstm.py @@ -15,8 +15,8 @@ from fairseq.models import ( FairseqDecoder, FairseqEncoder, - FairseqIncrementalDecoder, FairseqEncoderDecoderModel, + FairseqIncrementalDecoder, register_model, register_model_architecture, ) @@ -135,8 +135,12 @@ def build_model(cls, args, task): # make sure that all args are properly defaulted (in case there are any new ones) base_architecture(args) - max_source_positions = getattr(args, "max_source_positions", DEFAULT_MAX_SOURCE_POSITIONS) - max_target_positions = getattr(args, "max_target_positions", DEFAULT_MAX_TARGET_POSITIONS) + max_source_positions = getattr( + args, "max_source_positions", DEFAULT_MAX_SOURCE_POSITIONS + ) + max_target_positions = getattr( + args, "max_target_positions", DEFAULT_MAX_TARGET_POSITIONS + ) def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): num_embeddings = len(dictionary) @@ -201,7 +205,7 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): dropout_out=args.encoder_rnn_dropout_out, bidirectional=args.encoder_rnn_bidirectional, residual=args.encoder_rnn_residual, - src_bucketed=(getattr(task.args, "num_batch_buckets", 0) > 0), + src_bucketed=(getattr(task.cfg, "num_batch_buckets", 0) > 0), max_source_positions=max_source_positions, ) decoder = SpeechLSTMDecoder( @@ -221,7 +225,8 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): share_input_output_embed=args.share_decoder_input_output_embed, adaptive_softmax_cutoff=( utils.eval_str_list(args.adaptive_softmax_cutoff, type=int) - if args.criterion == "adaptive_loss" else None + if args.criterion == "adaptive_loss" + else None ), max_target_positions=max_target_positions, scheduled_sampling_rate_scheduler=scheduled_sampling_rate_scheduler, @@ -247,8 +252,10 @@ def forward( ): encoder_out = self.encoder(src_tokens, src_lengths=src_lengths) decoder_out = self.decoder( - prev_output_tokens, encoder_out=encoder_out, - incremental_state=incremental_state, epoch=epoch, + prev_output_tokens, + encoder_out=encoder_out, + incremental_state=incremental_state, + epoch=epoch, ) return decoder_out @@ -260,14 +267,16 @@ def max_positions(self): """Maximum length supported by the model.""" return ( self.encoder.max_positions(), - self.decoder.max_positions() if self.pretrained_lm is None else - min(self.decoder.max_positions(), self.pretrained_lm.max_positions()), + self.decoder.max_positions() if self.pretrained_lm is 
None + else min(self.decoder.max_positions(), self.pretrained_lm.max_positions()), ) def max_decoder_positions(self): """Maximum length supported by the decoder.""" - return self.decoder.max_positions() if self.pretrained_lm is None else \ - min(self.decoder.max_positions(), self.pretrained_lm.max_positions()) + return ( + self.decoder.max_positions() if self.pretrained_lm is None + else min(self.decoder.max_positions(), self.pretrained_lm.max_positions()) + ) class ConvBNReLU(nn.Module): @@ -327,29 +336,46 @@ def forward(self, src, src_lengths): class SpeechLSTMEncoder(FairseqEncoder): """LSTM encoder.""" def __init__( - self, conv_layers_before=None, input_size=83, hidden_size=512, - num_layers=1, dropout_in=0.1, dropout_out=0.1, bidirectional=False, - residual=False, left_pad=False, padding_value=0., src_bucketed=False, + self, + conv_layers_before=None, + input_size=83, + hidden_size=512, + num_layers=1, + dropout_in=0.1, + dropout_out=0.1, + bidirectional=False, + residual=False, + left_pad=False, + padding_value=0.0, + src_bucketed=False, max_source_positions=DEFAULT_MAX_SOURCE_POSITIONS, ): super().__init__(None) # no src dictionary self.conv_layers_before = conv_layers_before self.num_layers = num_layers - self.dropout_in_module = FairseqDropout(dropout_in, module_name=self.__class__.__name__) - self.dropout_out_module = FairseqDropout(dropout_out, module_name=self.__class__.__name__) + self.dropout_in_module = FairseqDropout( + dropout_in, module_name=self.__class__.__name__ + ) + self.dropout_out_module = FairseqDropout( + dropout_out, module_name=self.__class__.__name__ + ) self.bidirectional = bidirectional self.hidden_size = hidden_size self.residual = residual self.max_source_positions = max_source_positions - self.lstm = nn.ModuleList([ - LSTM( - input_size=input_size if layer == 0 else 2 * hidden_size if self.bidirectional else hidden_size, - hidden_size=hidden_size, - bidirectional=bidirectional, - ) - for layer in range(num_layers) - ]) + self.lstm = nn.ModuleList( + [ + LSTM( + input_size=input_size if layer == 0 + else 2 * hidden_size if self.bidirectional + else hidden_size, + hidden_size=hidden_size, + bidirectional=bidirectional, + ) + for layer in range(num_layers) + ] + ) self.left_pad = left_pad self.padding_value = padding_value self.src_bucketed = src_bucketed @@ -359,8 +385,10 @@ def __init__( self.output_units *= 2 def output_lengths(self, in_lengths): - return in_lengths if self.conv_layers_before is None \ + return ( + in_lengths if self.conv_layers_before is None else self.conv_layers_before.output_lengths(in_lengths) + ) def forward( self, @@ -392,8 +420,10 @@ def forward( if self.conv_layers_before is not None: x, src_lengths, padding_mask = self.conv_layers_before(src_tokens, src_lengths) else: - x, padding_mask = src_tokens, \ + x, padding_mask = ( + src_tokens, ~speech_utils.sequence_mask(src_lengths, src_tokens.size(1)) + ) bsz, seqlen = x.size(0), x.size(1) @@ -422,7 +452,9 @@ def forward( packed_outs, (_, _) = self.lstm[i](packed_x, (h0, c0)) # unpack outputs and apply dropout - x, _ = nn.utils.rnn.pad_packed_sequence(packed_outs, padding_value=self.padding_value*1.0) + x, _ = nn.utils.rnn.pad_packed_sequence( + packed_outs, padding_value=self.padding_value * 1.0 + ) if i < len(self.lstm) - 1: # not applying dropout for the last layer x = self.dropout_out_module(x) x = x + prev_x if self.residual and i > 0 else x @@ -432,7 +464,8 @@ def forward( return EncoderOut( encoder_out=x, # T x B x C - encoder_padding_mask=encoder_padding_mask if 
encoder_padding_mask.any() else None, # T x B + encoder_padding_mask=encoder_padding_mask if encoder_padding_mask.any() + else None, # T x B encoder_embedding=None, encoder_states=None, src_tokens=None, @@ -469,16 +502,32 @@ def max_positions(self): class SpeechLSTMDecoder(FairseqIncrementalDecoder): """LSTM decoder.""" def __init__( - self, dictionary, embed_dim=512, hidden_size=512, out_embed_dim=512, - num_layers=1, dropout_in=0.1, dropout_out=0.1, encoder_output_units=0, - attn_type=None, attn_dim=0, need_attn=False, residual=False, pretrained_embed=None, - share_input_output_embed=False, adaptive_softmax_cutoff=None, + self, + dictionary, + embed_dim=512, + hidden_size=512, + out_embed_dim=512, + num_layers=1, + dropout_in=0.1, + dropout_out=0.1, + encoder_output_units=0, + attn_type=None, + attn_dim=0, + need_attn=False, + residual=False, + pretrained_embed=None, + share_input_output_embed=False, + adaptive_softmax_cutoff=None, max_target_positions=DEFAULT_MAX_TARGET_POSITIONS, scheduled_sampling_rate_scheduler=None, ): super().__init__(dictionary) - self.dropout_in_module = FairseqDropout(dropout_in, module_name=self.__class__.__name__) - self.dropout_out_module = FairseqDropout(dropout_out, module_name=self.__class__.__name__) + self.dropout_in_module = FairseqDropout( + dropout_in, module_name=self.__class__.__name__ + ) + self.dropout_out_module = FairseqDropout( + dropout_out, module_name=self.__class__.__name__ + ) self.hidden_size = hidden_size self.share_input_output_embed = share_input_output_embed if attn_type is None or attn_type.lower() == "none": @@ -500,13 +549,16 @@ def __init__( self.encoder_output_units = encoder_output_units - self.layers = nn.ModuleList([ - LSTMCell( - input_size=encoder_output_units + (embed_dim if layer == 0 else hidden_size), - hidden_size=hidden_size, - ) - for layer in range(num_layers) - ]) + self.layers = nn.ModuleList( + [ + LSTMCell( + input_size=encoder_output_units + + (embed_dim if layer == 0 else hidden_size), + hidden_size=hidden_size, + ) + for layer in range(num_layers) + ] + ) if attn_type is None or attn_type.lower() == "none": self.attention = None @@ -527,7 +579,10 @@ def __init__( if adaptive_softmax_cutoff is not None: # setting adaptive_softmax dropout to dropout_out for now but can be redefined self.adaptive_softmax = AdaptiveSoftmax( - num_embeddings, hidden_size, adaptive_softmax_cutoff, dropout=dropout_out, + num_embeddings, + hidden_size, + adaptive_softmax_cutoff, + dropout=dropout_out, ) elif not self.share_input_output_embed: self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out) @@ -641,8 +696,10 @@ def extract_features( zero_state = x.new_zeros(bsz, self.hidden_size) prev_hiddens = [zero_state for i in range(self.num_layers)] prev_cells = [zero_state for i in range(self.num_layers)] - input_feed = x.new_zeros(bsz, self.encoder_output_units) \ - if encoder_out is not None else None + input_feed = ( + x.new_zeros(bsz, self.encoder_output_units) if encoder_out is not None + else None + ) attn_scores = x.new_zeros(srclen, seqlen, bsz) if encoder_out is not None else None outs = [] @@ -746,7 +803,9 @@ def get_cached_state( assert prev_cells_ is not None prev_hiddens = [prev_hiddens_[i] for i in range(self.num_layers)] prev_cells = [prev_cells_[j] for j in range(self.num_layers)] - input_feed = cached_state["input_feed"] # can be None for decoder-only language models + input_feed = cached_state[ + "input_feed" + ] # can be None for decoder-only language models return prev_hiddens, prev_cells, input_feed 
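The hunks above restructure how SpeechLSTMDecoder caches its per-step state (prev_hiddens, prev_cells, input_feed) and how that cache is reordered during beam search. As a rough, self-contained illustration of the underlying pattern (the dictionary layout and the reorder helper below are hypothetical, not Espresso's actual API), the per-layer states are stacked so that reordering the surviving beams reduces to one index_select per tensor:

import torch

num_layers, bsz, hidden = 2, 4, 8
cached = {
    "prev_hiddens": torch.zeros(num_layers, bsz, hidden),  # stacked per-layer h
    "prev_cells": torch.zeros(num_layers, bsz, hidden),    # stacked per-layer c
    "input_feed": torch.zeros(bsz, hidden),                # may be None for decoder-only LMs
}

def reorder(cache, new_order):
    # new_order lists the batch/beam indices kept after the last search step
    out = {}
    for key, value in cache.items():
        if value is None:
            out[key] = None
        elif key == "input_feed":            # shape [bsz, hidden]
            out[key] = value.index_select(0, new_order)
        else:                                # shape [num_layers, bsz, hidden]
            out[key] = value.index_select(1, new_order)
    return out

new_order = torch.tensor([1, 1, 0, 2])  # e.g. beam 1 duplicated, beam 3 dropped
cached = reorder(cached, new_order)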
def reorder_incremental_state( @@ -767,7 +826,7 @@ def reorder_incremental_state( "prev_hiddens": torch.stack(prev_hiddens), "prev_cells": torch.stack(prev_cells), "input_feed": input_feed, - } + }, ) self.set_incremental_state(incremental_state, "cached_state", cached_state_new), return @@ -777,8 +836,9 @@ def masked_copy_incremental_state(self, incremental_state, another_cached_state, assert another_cached_state is None or len(another_cached_state) == 0 return prev_hiddens, prev_cells, input_feed = self.get_cached_state(incremental_state) - another_prev_hiddens, another_prev_cells, another_input_feed = \ + another_prev_hiddens, another_prev_cells, another_input_feed = ( another_cached_state[0], another_cached_state[1], another_cached_state[2] + ) def mask_copy_state(state: Optional[Tensor], another_state: Optional[Tensor]): if state is None: @@ -807,7 +867,7 @@ def mask_copy_state(state: Optional[Tensor], another_state: Optional[Tensor]): "prev_hiddens": torch.stack(prev_hiddens_new), "prev_cells": torch.stack(prev_cells_new), "input_feed": input_feed_new, - } + }, ) self.set_incremental_state(incremental_state, "cached_state", cached_state_new) diff --git a/espresso/models/speech_lstm_encoder_model.py b/espresso/models/speech_lstm_encoder_model.py index c7eed3833..565baa2cd 100644 --- a/espresso/models/speech_lstm_encoder_model.py +++ b/espresso/models/speech_lstm_encoder_model.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from argparse import Namespace import logging from typing import Optional @@ -18,6 +19,7 @@ ) from fairseq.models.fairseq_encoder import EncoderOut from fairseq.models.lstm import Linear +from omegaconf import DictConfig from espresso.models.speech_lstm import ConvBNReLU, SpeechLSTMEncoder import espresso.tools.utils as speech_utils @@ -72,7 +74,9 @@ def build_model(cls, args, task): """Build a new model instance.""" # make sure that all args are properly defaulted (in case there are any new ones) base_architecture(args) - max_source_positions = getattr(args, "max_source_positions", DEFAULT_MAX_SOURCE_POSITIONS) + max_source_positions = getattr( + args, "max_source_positions", DEFAULT_MAX_SOURCE_POSITIONS + ) out_channels = speech_utils.eval_str_nested_list_or_tuple(args.encoder_conv_channels, type=int) kernel_sizes = speech_utils.eval_str_nested_list_or_tuple(args.encoder_conv_kernel_sizes, type=int) @@ -106,7 +110,7 @@ def build_model(cls, args, task): dropout_out=args.encoder_rnn_dropout_out, bidirectional=args.encoder_rnn_bidirectional, residual=args.encoder_rnn_residual, - src_bucketed=(getattr(task.args, "num_batch_buckets", 0) > 0), + src_bucketed=(getattr(task.cfg, "num_batch_buckets", 0) > 0), num_targets=getattr(task, "num_targets", None), # targets for encoder-only model chunk_width=getattr(task, "chunk_width", None), chunk_left_context=getattr(task, "chunk_left_context", 0), @@ -140,43 +144,81 @@ def state_dict(self): state_dict["state_prior"] = self.state_prior return state_dict - def load_state_dict(self, state_dict, strict=True, args=None): + def load_state_dict( + self, + state_dict, + strict=True, + model_cfg: Optional[DictConfig] = None, + args: Optional[Namespace] = None, + ): state_dict_subset = state_dict.copy() self.state_prior = state_dict.get("state_prior", None) if "state_prior" in state_dict: self.state_prior = state_dict["state_prior"] del state_dict_subset["state_prior"] - super().load_state_dict(state_dict_subset, strict=strict, args=args) + 
super().load_state_dict( + state_dict_subset, strict=strict, model_cfg=model_cfg, args=args + ) class SpeechChunkLSTMEncoder(SpeechLSTMEncoder): """LSTM encoder.""" def __init__( - self, conv_layers_before=None, input_size=83, hidden_size=512, - num_layers=1, dropout_in=0.1, dropout_out=0.1, bidirectional=False, - residual=False, left_pad=False, padding_value=0., src_bucketed=False, - num_targets=None, chunk_width=20, chunk_left_context=0, training_stage=True, + self, + conv_layers_before=None, + input_size=83, + hidden_size=512, + num_layers=1, + dropout_in=0.1, + dropout_out=0.1, + bidirectional=False, + residual=False, + left_pad=False, + padding_value=0.0, + src_bucketed=False, + num_targets=None, + chunk_width=20, + chunk_left_context=0, + training_stage=True, max_source_positions=DEFAULT_MAX_SOURCE_POSITIONS, ): super().__init__( - conv_layers_before=conv_layers_before, input_size=input_size, hidden_size=hidden_size, - num_layers=num_layers, dropout_in=dropout_in, dropout_out=dropout_out, - bidirectional=bidirectional, residual=residual, left_pad=left_pad, - padding_value=padding_value, src_bucketed=src_bucketed, max_source_positions=max_source_positions, + conv_layers_before=conv_layers_before, + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + dropout_in=dropout_in, + dropout_out=dropout_out, + bidirectional=bidirectional, + residual=residual, + left_pad=left_pad, + padding_value=padding_value, + src_bucketed=src_bucketed, + max_source_positions=max_source_positions, + ) + receptive_field_radius = ( + sum(conv.padding[0] for conv in conv_layers_before.convolutions) + if conv_layers_before is not None + else 0 ) - receptive_field_radius = sum(conv.padding[0] for conv in conv_layers_before.convolutions) \ - if conv_layers_before is not None else 0 assert chunk_width is None or chunk_width > 0 - assert (conv_layers_before is None and chunk_left_context >= 0) or \ - (conv_layers_before is not None and chunk_left_context >= receptive_field_radius) + assert ( + (conv_layers_before is None and chunk_left_context >= 0) + or (conv_layers_before is not None and chunk_left_context >= receptive_field_radius) + ) self.out_chunk_begin = self.output_lengths(chunk_left_context + 1) - 1 - self.out_chunk_end = self.output_lengths(chunk_left_context + chunk_width) \ - if chunk_width is not None else None + self.out_chunk_end = ( + self.output_lengths(chunk_left_context + chunk_width) if chunk_width is not None + else None + ) self.training_stage = training_stage # only for encoder-only model - self.fc_out = Linear(self.output_units, num_targets, dropout=self.dropout_out_module.p) \ - if num_targets is not None else None + self.fc_out = ( + Linear(self.output_units, num_targets, dropout=self.dropout_out_module.p) + if num_targets is not None + else None + ) def forward( self, @@ -214,7 +256,8 @@ def forward( return EncoderOut( encoder_out=x, # T x B x C - encoder_padding_mask=encoder_padding_mask if encoder_padding_mask.any() else None, # T x B + encoder_padding_mask=encoder_padding_mask if encoder_padding_mask.any() + else None, # T x B encoder_embedding=None, encoder_states=None, src_tokens=None, diff --git a/espresso/models/speech_tdnn.py b/espresso/models/speech_tdnn.py index 8d22a5e2e..214a4d0a7 100644 --- a/espresso/models/speech_tdnn.py +++ b/espresso/models/speech_tdnn.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
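The load_state_dict overrides rewritten in this and the following encoder models all follow the same pattern: the checkpoint may carry an extra "state_prior" tensor that is not a module parameter, so it is popped from a copy of the state dict before the remaining entries are handed to the parent implementation (which now also receives the new model_cfg argument). A minimal sketch of that pattern, using a made-up module in place of the Espresso classes:

from typing import Optional

import torch
import torch.nn as nn

class EncoderWithPrior(nn.Module):
    # Toy stand-in for the encoder models above: an optional extra tensor
    # travels with the checkpoint but must not reach nn.Module's loader.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)
        self.state_prior: Optional[torch.Tensor] = None

    def state_dict(self, *args, **kwargs):
        sd = super().state_dict(*args, **kwargs)
        if self.state_prior is not None:
            sd["state_prior"] = self.state_prior
        return sd

    def load_state_dict(self, state_dict, strict=True):
        subset = dict(state_dict)                  # keep the caller's dict intact
        self.state_prior = subset.pop("state_prior", None)
        return super().load_state_dict(subset, strict=strict)

src, dst = EncoderWithPrior(), EncoderWithPrior()
src.state_prior = torch.full((10,), 0.1)
dst.load_state_dict(src.state_dict())
assert dst.state_prior is not None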
+from argparse import Namespace import logging from typing import Optional @@ -21,6 +22,7 @@ from fairseq.models.fairseq_encoder import EncoderOut from fairseq.models.lstm import Linear from fairseq.modules import FairseqDropout +from omegaconf import DictConfig import espresso.tools.utils as speech_utils @@ -125,13 +127,21 @@ def state_dict(self): state_dict["state_prior"] = self.state_prior return state_dict - def load_state_dict(self, state_dict, strict=True, args=None): + def load_state_dict( + self, + state_dict, + strict=True, + model_cfg: Optional[DictConfig] = None, + args: Optional[Namespace] = None, + ): state_dict_subset = state_dict.copy() self.state_prior = state_dict.get("state_prior", None) if "state_prior" in state_dict: self.state_prior = state_dict["state_prior"] del state_dict_subset["state_prior"] - super().load_state_dict(state_dict_subset, strict=strict, args=args) + super().load_state_dict( + state_dict_subset, strict=strict, model_cfg=model_cfg, args=args + ) class TdnnBNReLU(nn.Module): @@ -192,8 +202,12 @@ def __init__( dilations = [dilations] * num_layers else: assert len(dilations) == num_layers - self.dropout_in_module = FairseqDropout(dropout_in, module_name=self.__class__.__name__) - self.dropout_out_module = FairseqDropout(dropout_out, module_name=self.__class__.__name__) + self.dropout_in_module = FairseqDropout( + dropout_in, module_name=self.__class__.__name__ + ) + self.dropout_out_module = FairseqDropout( + dropout_out, module_name=self.__class__.__name__ + ) self.residual = residual self.tdnn = nn.ModuleList([ @@ -206,15 +220,22 @@ def __init__( ]) receptive_field_radius = sum(layer.padding for layer in self.tdnn) - assert chunk_width is None or (chunk_width > 0 and chunk_left_context >= receptive_field_radius) + assert ( + chunk_width is None + or (chunk_width > 0 and chunk_left_context >= receptive_field_radius) + ) if ( chunk_width is not None and chunk_width > 0 and chunk_left_context > receptive_field_radius ): - logger.warning("chunk_{{left,right}}_context can be reduced to {}".format(receptive_field_radius)) + logger.warning( + "chunk_{{left,right}}_context can be reduced to {}".format(receptive_field_radius) + ) self.out_chunk_begin = self.output_lengths(chunk_left_context + 1) - 1 - self.out_chunk_end = self.output_lengths(chunk_left_context + chunk_width) \ - if chunk_width is not None else None + self.out_chunk_end = ( + self.output_lengths(chunk_left_context + chunk_width) if chunk_width is not None + else None + ) self.training_stage = training_stage self.fc_out = Linear(hidden_sizes[-1], output_size, dropout=self.dropout_out_module.p) diff --git a/espresso/models/speech_transformer.py b/espresso/models/speech_transformer.py index 5386b7299..a381b1a77 100644 --- a/espresso/models/speech_transformer.py +++ b/espresso/models/speech_transformer.py @@ -43,7 +43,7 @@ logger = logging.getLogger(__name__) -@register_model('speech_transformer') +@register_model("speech_transformer") class SpeechTransformerModel(TransformerModel): """ Transformer model from `"Attention Is All You Need" (Vaswani, et al, 2017) @@ -246,7 +246,9 @@ def __init__(self, args, conv_layers_before=None, input_size=83, transformer_con super(TransformerEncoder, self).__init__(None) # no src dictionary self.register_buffer("version", torch.Tensor([3])) - self.dropout_module = FairseqDropout(args.dropout, module_name=self.__class__.__name__) + self.dropout_module = FairseqDropout( + args.dropout, module_name=self.__class__.__name__ + ) self.encoder_layerdrop = 
args.encoder_layerdrop embed_dim = args.encoder_embed_dim @@ -257,7 +259,7 @@ def __init__(self, args, conv_layers_before=None, input_size=83, transformer_con self.embed_positions = ( PositionalEmbedding( - self.output_lengths(args.max_source_positions), + self.output_lengths(self.max_source_positions), embed_dim, 0, learned=args.encoder_learned_pos, @@ -297,8 +299,10 @@ def __init__(self, args, conv_layers_before=None, input_size=83, transformer_con self.transformer_context = transformer_context def output_lengths(self, in_lengths): - return in_lengths if self.conv_layers_before is None \ + return ( + in_lengths if self.conv_layers_before is None else self.conv_layers_before.output_lengths(in_lengths) + ) def get_attn_mask(self, in_lengths): """ @@ -360,8 +364,10 @@ def forward( if self.conv_layers_before is not None: x, src_lengths, encoder_padding_mask = self.conv_layers_before(src_tokens, src_lengths) else: - x, encoder_padding_mask = src_tokens, \ + x, encoder_padding_mask = ( + src_tokens, ~speech_utils.sequence_mask(src_lengths, src_tokens.size(1)) + ) x = self.dropout_module(x) if self.fc0 is not None: @@ -579,6 +585,15 @@ def base_architecture(args): args.no_scale_embedding = getattr(args, "no_scale_embedding", False) args.layernorm_embedding = getattr(args, "layernorm_embedding", False) args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) + args.checkpoint_activations = getattr(args, "checkpoint_activations", False) + + args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) + args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None) + args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0) + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0) + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) + args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8) + args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0) @register_model_architecture("speech_transformer", "speech_transformer_wsj") diff --git a/espresso/models/speech_transformer_encoder_model.py b/espresso/models/speech_transformer_encoder_model.py index 989a35986..9fccf6813 100644 --- a/espresso/models/speech_transformer_encoder_model.py +++ b/espresso/models/speech_transformer_encoder_model.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
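The base_architecture additions just above (checkpoint_activations, layerdrop, quantization noise) use the same getattr idiom as the rewritten LSTM LM presets earlier in this patch: a default is applied only when the attribute is not already set, so command-line overrides and named-architecture values always win. A small self-contained illustration (the preset name below is invented):

from argparse import Namespace

def base_architecture(args):
    # getattr keeps any value the user (or a named architecture) already set
    args.dropout = getattr(args, "dropout", 0.1)
    args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
    args.checkpoint_activations = getattr(args, "checkpoint_activations", False)

def lstm_lm_example(args):
    # a preset pins its own defaults, then falls through to the base defaults
    args.dropout = getattr(args, "dropout", 0.3)
    base_architecture(args)

args = Namespace(dropout=0.5)        # value supplied on the command line
lstm_lm_example(args)
assert args.dropout == 0.5           # preserved rather than clobbered
assert args.checkpoint_activations is False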
+from argparse import Namespace import logging from typing import Optional @@ -18,6 +19,7 @@ ) from fairseq.models.fairseq_encoder import EncoderOut from fairseq.models.transformer import Linear +from omegaconf import DictConfig from espresso.models.speech_lstm import ConvBNReLU from espresso.models.speech_transformer import SpeechTransformerEncoder @@ -70,6 +72,11 @@ def add_args(parser): "can be None or a tuple of two non-negative integers/None") parser.add_argument("--no-token-positional-embeddings", action="store_true", help="if set, disables positional embeddings (outside self attention)") + parser.add_argument("--layernorm-embedding", action="store_true", + help="add layernorm to embedding") + parser.add_argument("--checkpoint-activations", action="store_true", + help="checkpoint activations at each layer, which saves GPU " + "memory usage at the cost of some additional compute") # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019) parser.add_argument("--encoder-layerdrop", type=float, metavar="D", default=0, help="LayerDrop probability for encoder") @@ -191,13 +198,21 @@ def state_dict(self): state_dict["state_prior"] = self.state_prior return state_dict - def load_state_dict(self, state_dict, strict=True, args=None): + def load_state_dict( + self, + state_dict, + strict=True, + model_cfg: Optional[DictConfig] = None, + args: Optional[Namespace] = None, + ): state_dict_subset = state_dict.copy() self.state_prior = state_dict.get("state_prior", None) if "state_prior" in state_dict: self.state_prior = state_dict["state_prior"] del state_dict_subset["state_prior"] - super().load_state_dict(state_dict_subset, strict=strict, args=args) + super().load_state_dict( + state_dict_subset, strict=strict, model_cfg=model_cfg, args=args + ) class SpeechChunkTransformerEncoder(SpeechTransformerEncoder): @@ -210,19 +225,30 @@ def __init__( args, conv_layers_before=conv_layers_before, input_size=input_size, transformer_context=transformer_context, ) - receptive_field_radius = sum(conv.padding[0] for conv in conv_layers_before.convolutions) \ - if conv_layers_before is not None else 0 + receptive_field_radius = ( + sum(conv.padding[0] for conv in conv_layers_before.convolutions) + if conv_layers_before is not None + else 0 + ) assert chunk_width is None or chunk_width > 0 - assert (conv_layers_before is None and chunk_left_context >= 0) or \ - (conv_layers_before is not None and chunk_left_context >= receptive_field_radius) + assert ( + (conv_layers_before is None and chunk_left_context >= 0) + or (conv_layers_before is not None and chunk_left_context >= receptive_field_radius) + ) self.out_chunk_begin = self.output_lengths(chunk_left_context + 1) - 1 - self.out_chunk_end = self.output_lengths(chunk_left_context + chunk_width) \ - if chunk_width is not None else None + self.out_chunk_end = ( + self.output_lengths(chunk_left_context + chunk_width) + if chunk_width is not None + else None + ) self.training_stage = training_stage # only for encoder-only model - self.fc_out = Linear(args.encoder_embed_dim, num_targets, dropout=self.dropout_module.p) \ - if num_targets is not None else None + self.fc_out = ( + Linear(args.encoder_embed_dim, num_targets, dropout=self.dropout_module.p) + if num_targets is not None + else None + ) def forward( self, @@ -360,6 +386,13 @@ def base_architecture(args): ) args.adaptive_input = getattr(args, "adaptive_input", False) args.layernorm_embedding = getattr(args, "layernorm_embedding", False) + args.checkpoint_activations = 
getattr(args, "checkpoint_activations", False) + + args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) + args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0) + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) + args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8) + args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0) @register_model_architecture("speech_transformer_encoder_model", "speech_transformer_encoder_model_wsj") diff --git a/espresso/models/tensorized_lookahead_language_model.py b/espresso/models/tensorized_lookahead_language_model.py index 00cabc3f3..6be342a38 100644 --- a/espresso/models/tensorized_lookahead_language_model.py +++ b/espresso/models/tensorized_lookahead_language_model.py @@ -57,9 +57,10 @@ def __init__(self, super().__init__(word_lm.decoder.dictionary) self.lm_decoder: FairseqIncrementalDecoder = word_lm.decoder - assert hasattr(self.lm_decoder, 'masked_copy_incremental_state') and \ - callable(self.lm_decoder.masked_copy_incremental_state), \ - 'The wrapped decoder should implement masked_copy_incremental_state()' + assert ( + hasattr(self.lm_decoder, "masked_copy_incremental_state") + and callable(self.lm_decoder.masked_copy_incremental_state) + ), "The wrapped decoder should implement masked_copy_incremental_state()" self.oov_penalty = oov_penalty self.open_vocab = open_vocab @@ -76,7 +77,7 @@ def __init__(self, self.subword_vocab_size = len(subword_dict) def tokenizer(x: str) -> List[str]: - return tokenize(x, non_lang_syms=subword_dict.non_lang_syms).split(' ') + return tokenize(x, non_lang_syms=subword_dict.non_lang_syms).split(" ") self.tree = TensorizedPrefixTree.build(word_dict, subword_dict, tokenizer) assert self.tree.max_out_degree() <= self.subword_vocab_size @@ -86,7 +87,7 @@ def forward(self, prev_output_tokens: torch.Tensor, # Z_Tokens[Batch, SeqLength] encoder_out=None, incremental_state: Dict[str, Any] = None): - assert incremental_state is not None, 'This model is for incremental decoding only' + assert incremental_state is not None, "This model is for incremental decoding only" prev_output_tokens = prev_output_tokens[:, -1:] # Z_Tokens[Batch, Len=1] bsz = prev_output_tokens.size(0) @@ -95,11 +96,11 @@ def forward(self, # Move the batched state to the next state according to the automaton batch_space_mask = prev_output_tokens.squeeze(-1).eq(self.subword_space_idx) # B[Batch] - cached_state = self.lm_decoder.get_incremental_state(incremental_state, 'cached_state') + cached_state = self.lm_decoder.get_incremental_state(incremental_state, "cached_state") if cached_state is None: # First step assert (prev_output_tokens == self.subword_eos_idx).all(), \ - 'expecting the input to the first time step to be ' + "expecting the input to the first time step to be " w: torch.Tensor = prev_output_tokens.new_full([bsz, 1], self.word_eos_idx) # Z[Batch, Len=1] lm_probs: torch.Tensor = self.lm_decoder.get_normalized_probs( self.lm_decoder(w, incremental_state=incremental_state), @@ -110,9 +111,9 @@ def forward(self, else: # Not the first step cumsum_probs: torch.Tensor = self.get_incremental_state( - incremental_state, 'cumsum_probs', + incremental_state, "cumsum_probs", ) # R[Batch, 1, Vocab] - nodes: torch.Tensor = self.get_incremental_state(incremental_state, 'nodes') # Z_NodeId[Batch] + nodes: torch.Tensor = self.get_incremental_state(incremental_state, "nodes") # Z_NodeId[Batch] assert nodes.size(0) == bsz w: torch.Tensor = self.tree.word_idx[nodes].unsqueeze(1) # Z[Batch, 
Len=1] w[w < 0] = self.word_unk_idx @@ -139,8 +140,8 @@ def forward(self, all_children = self.tree.children[nodes, :] # Z[Batch, PossibleChildren] - self.set_incremental_state(incremental_state, 'cumsum_probs', cumsum_probs) - self.set_incremental_state(incremental_state, 'nodes', nodes) + self.set_incremental_state(incremental_state, "cumsum_probs", cumsum_probs) + self.set_incremental_state(incremental_state, "nodes", nodes) # Compute probabilities # initialize out_probs [Batch, 1, Vocab] @@ -161,7 +162,7 @@ def forward(self, # set transition probability to 1 for those whose node is out of the # tree, i.e. node is None (case 4 in Eqn. 15) - out_probs[nodes.eq(self.tree.none_id)] = 1. + out_probs[nodes.eq(self.tree.none_id)] = 1.0 else: # set out_probs to 0 out_probs = cumsum_probs.new_full([bsz, 1, self.subword_vocab_size], self.zero) @@ -226,15 +227,15 @@ def forward(self, def reorder_incremental_state(self, incremental_state, new_order): super().reorder_incremental_state(incremental_state, new_order) - cumsum_probs = self.get_incremental_state(incremental_state, 'cumsum_probs') + cumsum_probs = self.get_incremental_state(incremental_state, "cumsum_probs") if cumsum_probs is not None: new_cumsum_probs = cumsum_probs.index_select(0, new_order) - self.set_incremental_state(incremental_state, 'cumsum_probs', new_cumsum_probs) + self.set_incremental_state(incremental_state, "cumsum_probs", new_cumsum_probs) - nodes = self.get_incremental_state(incremental_state, 'nodes') + nodes = self.get_incremental_state(incremental_state, "nodes") if nodes is not None: new_nodes = nodes.index_select(0, new_order) - self.set_incremental_state(incremental_state, 'nodes', new_nodes) + self.set_incremental_state(incremental_state, "nodes", new_nodes) def get_normalized_probs(self, net_output, log_probs, sample=None): """Get normalized probabilities (or log probs) from a net's output.""" diff --git a/espresso/modules/__init__.py b/espresso/modules/__init__.py index 1e6b35acd..08e3bb3ca 100644 --- a/espresso/modules/__init__.py +++ b/espresso/modules/__init__.py @@ -7,6 +7,6 @@ __all__ = [ - 'BahdanauAttention', - 'LuongAttention', + "BahdanauAttention", + "LuongAttention", ] diff --git a/espresso/modules/speech_attention.py b/espresso/modules/speech_attention.py index fcba56e91..73b4132e6 100644 --- a/espresso/modules/speech_attention.py +++ b/espresso/modules/speech_attention.py @@ -54,8 +54,8 @@ def reset_parameters(self): self.value_proj.weight.data.uniform_(-0.1, 0.1) nn.init.uniform_(self.v, -0.1, 0.1) if self.normalize: - nn.init.constant_(self.b, 0.) - nn.init.constant_(self.g, math.sqrt(1. 
/ self.embed_dim)) + nn.init.constant_(self.b, 0.0) + nn.init.constant_(self.g, math.sqrt(1.0 / self.embed_dim)) def forward(self, query, value, key_padding_mask=None, state=None): # projected_query: 1 x bsz x embed_dim @@ -71,9 +71,11 @@ def forward(self, query, value, key_padding_mask=None, state=None): attn_scores = self.v * torch.tanh(projected_query + key).sum(dim=2) if key_padding_mask is not None: - attn_scores = attn_scores.float().masked_fill_( - key_padding_mask, float('-inf'), - ).type_as(attn_scores) # FP16 support: cast to float and back + attn_scores = ( + attn_scores.float() + .masked_fill_(key_padding_mask, float("-inf")) + .type_as(attn_scores) + ) # FP16 support: cast to float and back attn_scores = F.softmax(attn_scores, dim=0) # srclen x bsz @@ -99,7 +101,7 @@ def __init__(self, query_dim, value_dim, embed_dim=None, scale=True): def reset_parameters(self): self.value_proj.weight.data.uniform_(-0.1, 0.1) if self.scale: - nn.init.constant_(self.g, 1.) + nn.init.constant_(self.g, 1.0) def forward(self, query, value, key_padding_mask=None, state=None): query = query.unsqueeze(1) # bsz x 1 x query_dim @@ -110,9 +112,11 @@ def forward(self, query, value, key_padding_mask=None, state=None): attn_scores = self.g * attn_scores if key_padding_mask is not None: - attn_scores = attn_scores.float().masked_fill_( - key_padding_mask, float('-inf'), - ).type_as(attn_scores) # FP16 support: cast to float and back + attn_scores = ( + attn_scores.float() + .masked_fill_(key_padding_mask, float("-inf")) + .type_as(attn_scores) + ) # FP16 support: cast to float and back attn_scores = F.softmax(attn_scores, dim=0) # srclen x bsz diff --git a/espresso/optim/lr_scheduler/reduce_lr_on_plateau_v2.py b/espresso/optim/lr_scheduler/reduce_lr_on_plateau_v2.py index e8e919860..c3f1141db 100644 --- a/espresso/optim/lr_scheduler/reduce_lr_on_plateau_v2.py +++ b/espresso/optim/lr_scheduler/reduce_lr_on_plateau_v2.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass, field -from omegaconf import II from typing import List import torch.optim.lr_scheduler @@ -13,6 +12,7 @@ from fairseq.dataclass.utils import gen_parser_from_dataclass from fairseq.optim.lr_scheduler import register_lr_scheduler from fairseq.optim.lr_scheduler.reduce_lr_on_plateau import ReduceLROnPlateau +from omegaconf import II, DictConfig @dataclass @@ -40,7 +40,7 @@ class ReduceLROnPlateauV2Config(FairseqDataclass): warmup_init_lr: float = field( default=-1, metadata={ - "help": "initial learning rate during warmup phase; default is args.lr" + "help": "initial learning rate during warmup phase; default is cfg.lr" }, ) final_lr_scale: float = field( @@ -52,25 +52,30 @@ class ReduceLROnPlateauV2Config(FairseqDataclass): metadata={"help": "start to reduce lr from the specified epoch"}, ) # TODO common vars at parent class - lr: List[float] = II("params.optimization.lr") + lr: List[float] = II("optimization.lr") + maximize_best_checkpoint_metric: bool = II("checkpoint.maximize_best_checkpoint_metric") @register_lr_scheduler("reduce_lr_on_plateau_v2", dataclass=ReduceLROnPlateauV2Config) class ReduceLROnPlateauV2(ReduceLROnPlateau): """Decay the LR by a factor every time the validation loss plateaus, starting - from the epoch specified as args.start_reduce_lr_epoch. + from the epoch specified as cfg.start_reduce_lr_epoch. We also support specifying a final lr which will be kept until the max number of epochs is reached. 
""" - def __init__(self, args, optimizer): - super().__init__(args, optimizer) + def __init__(self, cfg: DictConfig, fairseq_optimizer): + super().__init__(cfg, fairseq_optimizer) + self.cfg = cfg self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - self.optimizer.optimizer, patience=args.lr_patience, factor=args.lr_shrink, - mode="max" if args.maximize_best_checkpoint_metric else "min", - threshold=args.lr_threshold, min_lr=args.final_lr_scale * args.lr[0] + self.optimizer.optimizer, + patience=cfg.lr_patience, + factor=cfg.lr_shrink, + mode="max" if cfg.maximize_best_checkpoint_metric else "min", + threshold=cfg.lr_threshold, + min_lr=cfg.final_lr_scale * cfg.lr[0], ) @classmethod @@ -80,8 +85,8 @@ def add_args(cls, parser): gen_parser_from_dataclass(parser, dc()) def step(self, epoch, val_loss=None): - if epoch < self.args.start_reduce_lr_epoch: + if epoch < self.cfg.start_reduce_lr_epoch: self.lr_scheduler.last_epoch = epoch - self.optimizer.set_lr(self.args.lr[0]) + self.optimizer.set_lr(self.cfg.lr[0]) return self.optimizer.get_lr() return super().step(epoch, val_loss) diff --git a/espresso/speech_recognize.py b/espresso/speech_recognize.py index a08f10120..5c982f464 100755 --- a/espresso/speech_recognize.py +++ b/espresso/speech_recognize.py @@ -14,15 +14,16 @@ import math import os import sys +from argparse import Namespace import numpy as np - import torch - from fairseq import checkpoint_utils, options, tasks, utils +from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import progress_bar from fairseq.logging.meters import StopwatchMeter, TimeMeter from fairseq.models import FairseqLanguageModel +from omegaconf import DictConfig from espresso.models.external_language_model import MultiLevelLanguageModel from espresso.models.tensorized_lookahead_language_model import TensorizedLookaheadLanguageModel @@ -30,17 +31,22 @@ from espresso.tools.utils import plot_attention, sequence_mask -def main(args): - assert args.path is not None, "--path required for recognition!" - assert not args.sampling or args.nbest == args.beam, \ - "--sampling requires --nbest to be equal to --beam" +def main(cfg: DictConfig): + + if isinstance(cfg, Namespace): + cfg = convert_namespace_to_omegaconf(cfg) + + assert cfg.common_eval.path is not None, "--path required for recognition!" 
+ assert ( + not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam + ), "--sampling requires --nbest to be equal to --beam" - if args.results_path is not None: - os.makedirs(args.results_path, exist_ok=True) - output_path = os.path.join(args.results_path, "decode.log") + if cfg.common_eval.results_path is not None: + os.makedirs(cfg.common_eval.results_path, exist_ok=True) + output_path = os.path.join(cfg.common_eval.results_path, "decode.log") with open(output_path, "w", buffering=1, encoding="utf-8") as h: - return _main(args, h) - return _main(args, sys.stdout) + return _main(cfg, h) + return _main(cfg, sys.stdout) def get_symbols_to_strip_from_output(generator): @@ -50,7 +56,7 @@ def get_symbols_to_strip_from_output(generator): return {generator.eos, generator.pad} -def _main(args, output_file): +def _main(cfg, output_file): logging.basicConfig( format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S", @@ -61,53 +67,53 @@ def _main(args, output_file): if output_file is not sys.stdout: # also print to stdout logger.addHandler(logging.StreamHandler(sys.stdout)) - print_options_meaning_changes(args, logger) + print_options_meaning_changes(cfg, logger) - utils.import_user_module(args) + utils.import_user_module(cfg.common) - if args.max_tokens is None and args.batch_size is None: - args.max_tokens = 12000 - logger.info(args) + if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None: + cfg.dataset.max_tokens = 12000 + logger.info(cfg) # Fix seed for stochastic decoding - if args.seed is not None and not args.no_seed_provided: - np.random.seed(args.seed) - utils.set_torch_seed(args.seed) + if cfg.common.seed is not None and not cfg.generation.no_seed_provided: + np.random.seed(cfg.common.seed) + utils.set_torch_seed(cfg.common.seed) - use_cuda = torch.cuda.is_available() and not args.cpu + use_cuda = torch.cuda.is_available() and not cfg.common.cpu # Load dataset split - task = tasks.setup_task(args) - task.load_dataset(args.gen_subset) + task = tasks.setup_task(cfg.task) + task.load_dataset(cfg.dataset.gen_subset) # Set dictionary dictionary = task.target_dictionary - overrides = ast.literal_eval(args.model_overrides) + overrides = ast.literal_eval(cfg.common_eval.model_overrides) # Load ensemble - logger.info("loading model(s) from {}".format(args.path)) + logger.info("loading model(s) from {}".format(cfg.common_eval.path)) models, _model_args = checkpoint_utils.load_model_ensemble( - utils.split_paths(args.path), + utils.split_paths(cfg.common_eval.path), arg_overrides=overrides, task=task, - suffix=getattr(args, "checkpoint_suffix", ""), - strict=(args.checkpoint_shard_count == 1), - num_shards=args.checkpoint_shard_count, + suffix=cfg.checkpoint.checkpoint_suffix, + strict=(cfg.checkpoint.checkpoint_shard_count == 1), + num_shards=cfg.checkpoint.checkpoint_shard_count, ) - if args.lm_path is not None: - overrides["data"] = args.data + if cfg.generation.lm_path is not None: + overrides["data"] = cfg.task.data try: lms, _ = checkpoint_utils.load_model_ensemble( - utils.split_paths(args.lm_path), - arg_overrides=overrides, - task=None, + utils.split_paths(cfg.generation.lm_path), arg_overrides=overrides, task=None, ) except: - logger.warning(f"Failed to load language model! Please make sure that the language model dict is the same " - f"as target dict and is located in the data dir ({args.data})") + logger.warning( + f"Failed to load language model! 
Please make sure that the language model dict is the same " + f"as target dict and is located in the data dir ({cfg.task.data})" + ) raise assert len(lms) == 1 or len(lms) == 2 # Multi-level LM expects two LMs @@ -122,61 +128,60 @@ def _main(args, output_file): if i > 0 and isinstance(lms[i - 1], FairseqLanguageModel): lms[i - 1] = MultiLevelLanguageModel( m, lms[i - 1], - subwordlm_weight=args.subwordlm_weight, - oov_penalty=args.oov_penalty, - open_vocab=not args.disable_open_vocab, + subwordlm_weight=cfg.generation.subwordlm_weight, + oov_penalty=cfg.generation.oov_penalty, + open_vocab=not cfg.generation.disable_open_vocab, ) del lms[i] logger.info("LM fusion with Multi-level LM") else: lms[i] = TensorizedLookaheadLanguageModel( m, dictionary, - oov_penalty=args.oov_penalty, - open_vocab=not args.disable_open_vocab, + oov_penalty=cfg.generation.oov_penalty, + open_vocab=not cfg.generation.disable_open_vocab, ) logger.info("LM fusion with Look-ahead Word LM") else: assert isinstance(m, FairseqLanguageModel) logger.info("LM fusion with Subword LM") - if args.lm_weight != 0.0: - logger.info("using LM fusion with lm-weight={:.2f}".format(args.lm_weight)) + if cfg.generation.lm_weight != 0.0: + logger.info("using LM fusion with lm-weight={:.2f}".format(cfg.generation.lm_weight)) # Optimize ensemble for generation for model in chain(models, lms): if model is None: continue - if args.fp16: + if cfg.common.fp16: model.half() - if use_cuda and not args.pipeline_model_parallel: + if use_cuda and not cfg.distributed_training.pipeline_model_parallel: model.cuda() - model.prepare_for_inference_(args) + model.prepare_for_inference_(cfg) # Load dataset (possibly sharded) itr = task.get_batch_iterator( - dataset=task.dataset(args.gen_subset), - max_tokens=args.max_tokens, - max_sentences=args.batch_size, + dataset=task.dataset(cfg.dataset.gen_subset), + max_tokens=cfg.dataset.max_tokens, + max_sentences=cfg.dataset.batch_size, max_positions=utils.resolve_max_positions( - task.max_positions(), - *[model.max_positions() if hasattr(model, "encoder") - else (None, model.max_positions()) for model in models] + task.max_positions(), *[m.max_positions() for m in models] ), - ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, - required_batch_size_multiple=args.required_batch_size_multiple, - num_shards=args.num_shards, - shard_id=args.shard_id, - num_workers=args.num_workers, - data_buffer_size=args.data_buffer_size, + ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=cfg.dataset.required_batch_size_multiple, + seed=cfg.common.seed, + num_shards=cfg.distributed_training.distributed_world_size, + shard_id=cfg.distributed_training.distributed_rank, + num_workers=cfg.dataset.num_workers, + data_buffer_size=cfg.dataset.data_buffer_size, ).next_epoch_itr(shuffle=False) progress = progress_bar.progress_bar( itr, - log_format=args.log_format, - log_interval=args.log_interval, - default_log_format=("tqdm" if not args.no_progress_bar else "none"), + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), ) # Initialize generator - if args.match_source_len: + if cfg.generation.match_source_len: logger.warning( "The option match_source_len is not applicable to speech recognition. Ignoring it." 
) @@ -184,18 +189,20 @@ def _main(args, output_file): extra_gen_cls_kwargs = { "lm_model": lms[0], - "lm_weight": args.lm_weight, - "eos_factor": args.eos_factor, + "lm_weight": cfg.generation.lm_weight, + "eos_factor": cfg.generation.eos_factor, } - args.score_reference = False # not applicable for ASR - temp_val = args.print_alignment - args.print_alignment = False # not applicable for ASR - generator = task.build_generator(models, args, extra_gen_cls_kwargs=extra_gen_cls_kwargs) - args.print_alignment = temp_val + cfg.generation.score_reference = False # not applicable for ASR + temp_val = cfg.generation.print_alignment + cfg.generation.print_alignment = False # not applicable for ASR + generator = task.build_generator( + models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs + ) + cfg.generation.print_alignment = temp_val # Handle tokenization and BPE - tokenizer = task.build_tokenizer(args) - bpe = task.build_bpe(args) + tokenizer = task.build_tokenizer(cfg.tokenizer) + bpe = task.build_bpe(cfg.bpe) def decode_fn(x): if bpe is not None: @@ -204,8 +211,8 @@ def decode_fn(x): x = tokenizer.decode(x) return x - # Generate and compute WER - scorer = wer.Scorer(dictionary, wer_output_filter=args.wer_output_filter) + scorer = wer.Scorer(dictionary, wer_output_filter=cfg.task.wer_output_filter) + num_sentences = 0 has_target = True wps_meter = TimeMeter() @@ -215,20 +222,26 @@ def decode_fn(x): continue prefix_tokens = None - if args.prefix_size > 0: - prefix_tokens = sample["target"][:, :args.prefix_size] + if cfg.generation.prefix_size > 0: + prefix_tokens = sample["target"][:, : cfg.generation.prefix_size] constraints = None if "constraints" in sample: constraints = sample["constraints"] gen_timer.start() - hypos = task.inference_step(generator, models, sample, prefix_tokens=prefix_tokens, constraints=constraints) + hypos = task.inference_step( + generator, + models, + sample, + prefix_tokens=prefix_tokens, + constraints=constraints, + ) num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos) gen_timer.stop(num_generated_tokens) # obtain nonpad mask of encoder output to plot attentions - if args.print_alignment: + if cfg.generation.print_alignment: net_input = sample["net_input"] src_tokens = net_input["src_tokens"] output_lengths = models[0].encoder.output_lengths(net_input["src_lengths"]) @@ -241,19 +254,19 @@ def decode_fn(x): # Retrieve the original sentences if has_target: target_str = sample["target_raw_text"][i] - if not args.quiet: + if not cfg.common_eval.quiet: detok_target_str = decode_fn(target_str) print("T-{}\t{}".format(utt_id, detok_target_str), file=output_file) # Process top predictions - for j, hypo in enumerate(hypos[i][:args.nbest]): + for j, hypo in enumerate(hypos[i][: cfg.generation.nbest]): hypo_str = dictionary.string( hypo["tokens"].int().cpu(), bpe_symbol=None, extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator), ) # not removing bpe at this point detok_hypo_str = decode_fn(hypo_str) - if not args.quiet: + if not cfg.common_eval.quiet: score = hypo["score"] / math.log(2) # convert to base 2 print("H-{}\t{}\t{}".format(utt_id, detok_hypo_str, score), file=output_file) @@ -261,9 +274,9 @@ def decode_fn(x): if j == 0: # src_len x tgt_len attention = hypo["attention"][nonpad_idxs[i]].float().cpu() \ - if args.print_alignment and hypo["attention"] is not None else None - if args.print_alignment and attention is not None: - save_dir = os.path.join(args.results_path, "attn_plots") + if cfg.generation.print_alignment and 
hypo["attention"] is not None else None + if cfg.generation.print_alignment and attention is not None: + save_dir = os.path.join(cfg.common_eval.results_path, "attn_plots") os.makedirs(save_dir, exist_ok=True) plot_attention(attention, detok_hypo_str, utt_id, save_dir) scorer.add_prediction(utt_id, hypo_str) @@ -277,26 +290,26 @@ def decode_fn(x): logger.info("NOTE: hypothesis and token scores are output in base 2") logger.info("Recognized {} utterances ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)".format( num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg)) - if args.print_alignment: + if cfg.generation.print_alignment: logger.info("Saved attention plots in " + save_dir) if has_target: - scorer.add_ordered_utt_list(task.datasets[args.gen_subset].tgt.utt_ids) + scorer.add_ordered_utt_list(task.datasets[cfg.dataset.gen_subset].tgt.utt_ids) fn = "decoded_char_results.txt" - with open(os.path.join(args.results_path, fn), "w", encoding="utf-8") as f: + with open(os.path.join(cfg.common_eval.results_path, fn), "w", encoding="utf-8") as f: f.write(scorer.print_char_results()) logger.info("Decoded char results saved as " + f.name) fn = "decoded_results.txt" - with open(os.path.join(args.results_path, fn), "w", encoding="utf-8") as f: + with open(os.path.join(cfg.common_eval.results_path, fn), "w", encoding="utf-8") as f: f.write(scorer.print_results()) logger.info("Decoded results saved as " + f.name) if has_target: - header = "Recognize {} with beam={}: ".format(args.gen_subset, args.beam) + header = "Recognize {} with beam={}: ".format(cfg.dataset.gen_subset, cfg.generation.beam) fn = "wer" - with open(os.path.join(args.results_path, fn), "w", encoding="utf-8") as f: + with open(os.path.join(cfg.common_eval.results_path, fn), "w", encoding="utf-8") as f: res = "WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%".format( *(scorer.wer())) logger.info(header + res) @@ -304,7 +317,7 @@ def decode_fn(x): logger.info("WER saved in " + f.name) fn = "cer" - with open(os.path.join(args.results_path, fn), "w", encoding="utf-8") as f: + with open(os.path.join(cfg.common_eval.results_path, fn), "w", encoding="utf-8") as f: res = "CER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%".format( *(scorer.cer())) logger.info(" " * len(header) + res) @@ -312,34 +325,23 @@ def decode_fn(x): logger.info("CER saved in " + f.name) fn = "aligned_results.txt" - with open(os.path.join(args.results_path, fn), "w", encoding="utf-8") as f: + with open(os.path.join(cfg.common_eval.results_path, fn), "w", encoding="utf-8") as f: f.write(scorer.print_aligned_results()) logger.info("Aligned results saved as " + f.name) return scorer -def print_options_meaning_changes(args, logger): +def print_options_meaning_changes(cfg, logger): """Options that have different meanings than those in the translation task are explained here. """ logger.info("--max-tokens is the maximum number of input frames in a batch") - if args.print_alignment: + if cfg.generation.print_alignment: logger.info("--print-alignment has been set to plot attentions") def cli_main(): parser = options.get_generation_parser(default_task="speech_recognition_espresso") - parser.add_argument("--eos-factor", default=None, type=float, metavar="F", - help="only consider emitting EOS if its score is no less " - "than the specified factor of the best candidate score") - parser.add_argument("--subwordlm-weight", default=0.8, type=float, metavar="W", - help="subword LM weight relative to word LM. 
Only relevant " - "to MultiLevelLanguageModel as an external LM") - parser.add_argument("--oov-penalty", default=1e-4, type=float, - help="oov penalty with the pretrained external LM") - parser.add_argument("--disable-open-vocab", action="store_true", - help="whether open vocabulary mode is enabled with the " - "pretrained external LM") args = options.parse_args_and_arch(parser) assert args.results_path is not None, "please specify --results-path" main(args) diff --git a/espresso/speech_train.py b/espresso/speech_train.py index 1c2b2b4f3..9391ba72e 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -8,13 +8,16 @@ Train a new model on one or across multiple GPUs. """ +import argparse import logging import math import os import sys +from typing import Dict, Optional, Any, List, Tuple, Callable import numpy as np import torch + from fairseq import ( checkpoint_utils, distributed_utils, @@ -24,8 +27,10 @@ utils, ) from fairseq.data import iterators +from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import meters, metrics, progress_bar from fairseq.model_parallel.megatron_trainer import MegatronTrainer +from omegaconf import DictConfig from fairseq.trainer import Trainer @@ -38,90 +43,89 @@ logger = logging.getLogger("espresso.speech_train") -def main(args): - utils.import_user_module(args) +def main(cfg: DictConfig) -> None: + if isinstance(cfg, argparse.Namespace): + cfg = convert_namespace_to_omegaconf(cfg) - assert ( - args.max_tokens is not None or args.batch_size is not None - ), "Must specify batch size either with --max-tokens or --batch-size" + utils.import_user_module(cfg.common) + assert cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None, \ + "Must specify batch size either with --max-tokens or --batch-size" metrics.reset() - np.random.seed(args.seed) - utils.set_torch_seed(args.seed) + np.random.seed(cfg.common.seed) + utils.set_torch_seed(cfg.common.seed) - if distributed_utils.is_master(args): - checkpoint_utils.verify_checkpoint_directory(args.save_dir) + if distributed_utils.is_master(cfg.distributed_training): + checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir) # Print args - logger.info(args) + logger.info(cfg) # Setup task, e.g., translation, language modeling, etc. - task = tasks.setup_task(args) - + task = tasks.setup_task(cfg.task) + # Handle tokenization and BPE + task.build_tokenizer(cfg.tokenizer) + task.build_bpe(cfg.bpe) # Load valid dataset (we load training data below, based on the latest checkpoint) - for valid_sub_split in args.valid_subset.split(","): + for valid_sub_split in cfg.dataset.valid_subset.split(","): task.load_dataset(valid_sub_split, combine=False, epoch=1) # Build model and criterion - model = task.build_model(args) - criterion = task.build_criterion(args) + model = task.build_model(cfg.model) + criterion = task.build_criterion(cfg.criterion) logger.info(model) - logger.info("task: {} ({})".format(args.task, task.__class__.__name__)) - logger.info("model: {} ({})".format(args.arch, model.__class__.__name__)) - logger.info( - "criterion: {} ({})".format(args.criterion, criterion.__class__.__name__) - ) + logger.info("task: {} ({})".format(cfg.task._name, task.__class__.__name__)) + logger.info("model: {} ({})".format(cfg.model._name, model.__class__.__name__)) logger.info( - "num. model params: {} (num. 
trained: {})".format( - sum(p.numel() for p in model.parameters()), - sum(p.numel() for p in model.parameters() if p.requires_grad), - ) + "criterion: {} ({})".format(cfg.criterion._name, criterion.__class__.__name__) ) + logger.info("num. model params: {} (num. trained: {})".format( + sum(p.numel() for p in model.parameters()), + sum(p.numel() for p in model.parameters() if p.requires_grad), + )) # (optionally) Configure quantization - if args.quantization_config_path is not None: + if cfg.common.quantization_config_path is not None: quantizer = quantization_utils.Quantizer( - config_path=args.quantization_config_path, - max_epoch=args.max_epoch, - max_update=args.max_update, + config_path=cfg.common.quantization_config_path, + max_epoch=cfg.optimization.max_epoch, + max_update=cfg.optimization.max_update, ) else: quantizer = None # Build trainer - if args.model_parallel_size == 1: - trainer = Trainer(args, task, model, criterion, quantizer) + if cfg.common.model_parallel_size == 1: + trainer = Trainer(cfg, task, model, criterion, quantizer) else: - trainer = MegatronTrainer(args, task, model, criterion) + trainer = MegatronTrainer(cfg, task, model, criterion) - logger.info( - "training on {} devices (GPUs/TPUs)".format(args.distributed_world_size) - ) - logger.info( - "max input frames per GPU = {} and max sentences per GPU = {}".format( - args.max_tokens, args.batch_size - ) - ) + logger.info("training on {} devices (GPUs/TPUs)".format(cfg.distributed_training.distributed_world_size)) + logger.info("max tokens per GPU = {} and batch size per GPU = {}".format( + cfg.dataset.max_tokens, + cfg.dataset.batch_size, + )) # Load the latest checkpoint if one is available and restore the # corresponding train iterator extra_state, epoch_itr = checkpoint_utils.load_checkpoint( - args, + cfg.checkpoint, trainer, # don't cache epoch iterators for sharded datasets disable_iterator_cache=task.has_sharded_data("train"), ) - # Train until the learning rate gets too small - max_epoch = args.max_epoch or math.inf + max_epoch = cfg.optimization.max_epoch or math.inf lr = trainer.get_lr() train_meter = meters.StopwatchMeter() train_meter.start() - - while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch: + while ( + lr > cfg.optimization.min_lr + and epoch_itr.next_epoch_idx <= max_epoch + ): # train for one epoch - valid_losses, should_stop = train(args, trainer, task, epoch_itr) + valid_losses, should_stop = train(cfg, trainer, task, epoch_itr) if should_stop: break @@ -139,15 +143,15 @@ def main(args): logger.info("done training in {:.1f} seconds".format(train_meter.sum)) -def should_stop_early(args, valid_loss): +def should_stop_early(cfg: DictConfig, valid_loss: float) -> bool: # skip check if no validation was done in the current epoch if valid_loss is None: return False - if args.patience <= 0: + if cfg.checkpoint.patience <= 0: return False def is_better(a, b): - return a > b if args.maximize_best_checkpoint_metric else a < b + return a > b if cfg.checkpoint.maximize_best_checkpoint_metric else a < b prev_best = getattr(should_stop_early, "best", None) if prev_best is None or is_better(valid_loss, prev_best): @@ -156,42 +160,41 @@ def is_better(a, b): return False else: should_stop_early.num_runs += 1 - if should_stop_early.num_runs >= args.patience: - logger.info( - "early stop since valid performance hasn't improved for last {} runs".format( - args.patience - ) - ) + if should_stop_early.num_runs >= cfg.checkpoint.patience: + logger.info("early stop since valid performance hasn't 
improved for last {} runs".format(cfg.checkpoint.patience)) return True else: return False @metrics.aggregate("train") -def train(args, trainer, task, epoch_itr): +def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr) -> Tuple[List[Optional[float]], bool]: """Train the model for one epoch and return validation losses.""" # Initialize data iterator itr = epoch_itr.next_epoch_itr( - fix_batches_to_gpus=args.fix_batches_to_gpus, - shuffle=(epoch_itr.next_epoch_idx > args.curriculum), + fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus, + shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum), ) update_freq = ( - args.update_freq[epoch_itr.epoch - 1] - if epoch_itr.epoch <= len(args.update_freq) - else args.update_freq[-1] + cfg.optimization.update_freq[epoch_itr.epoch - 1] + if epoch_itr.epoch <= len(cfg.optimization.update_freq) + else cfg.optimization.update_freq[-1] ) itr = iterators.GroupedIterator(itr, update_freq) - if getattr(args, "tpu", False): + if getattr(cfg.common, "tpu", False): itr = utils.tpu_data_loader(itr) progress = progress_bar.progress_bar( itr, - log_format=args.log_format, - log_interval=args.log_interval, + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, epoch=epoch_itr.epoch, tensorboard_logdir=( - args.tensorboard_logdir if distributed_utils.is_master(args) else None + cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None + ), + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), + wandb_project=( + cfg.common.wandb_project if distributed_utils.is_master(cfg.distributed_training) else None ), - default_log_format=("tqdm" if not args.no_progress_bar else "simple"), ) trainer.begin_epoch(epoch_itr.epoch) @@ -199,8 +202,7 @@ def train(args, trainer, task, epoch_itr): if hasattr(trainer.criterion, "set_epoch"): trainer.criterion.set_epoch(epoch_itr.epoch) - valid_losses = [None] - valid_subsets = args.valid_subset.split(",") + valid_subsets = cfg.dataset.valid_subset.split(",") should_stop = False num_updates = trainer.get_num_updates() for i, samples in enumerate(progress): @@ -212,7 +214,7 @@ def train(args, trainer, task, epoch_itr): if log_output is not None: # not OOM, overflow, ... 
# log mid-epoch stats num_updates = trainer.get_num_updates() - if num_updates % args.log_interval == 0: + if num_updates % cfg.common.log_interval == 0: stats = get_training_stats(metrics.get_smoothed_values("train_inner")) progress.log(stats, tag="train_inner", step=num_updates) @@ -220,13 +222,13 @@ def train(args, trainer, task, epoch_itr): # the end-of-epoch stats will still be preserved metrics.reset_meters("train_inner") - # update the state prior stored in the model for cross-entropy training + # update the state prior stored in the model for cross-entropy training of hybrid systems if hasattr(task, "update_state_prior"): task.update_state_prior(trainer.get_model()) end_of_epoch = not itr.has_next() valid_losses, should_stop = validate_and_save( - args, trainer, task, epoch_itr, valid_subsets, end_of_epoch + cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch ) if should_stop: @@ -242,84 +244,87 @@ def train(args, trainer, task, epoch_itr): return valid_losses, should_stop -def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoch): +def validate_and_save(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr, valid_subsets: List[str], end_of_epoch: bool) -> Tuple[List[Optional[float]], bool]: num_updates = trainer.get_num_updates() - max_update = args.max_update or math.inf + max_update = cfg.optimization.max_update or math.inf do_save = ( - (end_of_epoch and epoch_itr.epoch % args.save_interval == 0) + (end_of_epoch and epoch_itr.epoch % cfg.checkpoint.save_interval == 0) or num_updates >= max_update or ( - args.save_interval_updates > 0 + cfg.checkpoint.save_interval_updates > 0 and num_updates > 0 - and num_updates % args.save_interval_updates == 0 - and num_updates >= args.validate_after_updates + and num_updates % cfg.checkpoint.save_interval_updates == 0 + and num_updates >= cfg.dataset.validate_after_updates ) ) do_validate = ( (not end_of_epoch and do_save) # validate during mid-epoch saves - or (end_of_epoch and epoch_itr.epoch % args.validate_interval == 0) + or (end_of_epoch and epoch_itr.epoch % cfg.dataset.validate_interval == 0) or num_updates >= max_update or ( - args.validate_interval_updates > 0 + cfg.dataset.validate_interval_updates > 0 and num_updates > 0 - and num_updates % args.validate_interval_updates == 0 + and num_updates % cfg.dataset.validate_interval_updates == 0 ) - ) and not args.disable_validation + ) and not cfg.dataset.disable_validation # Validate valid_losses = [None] if do_validate: - valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets) + valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets) # Stopping conditions should_stop = ( - should_stop_early(args, valid_losses[0]) + should_stop_early(cfg, valid_losses[0]) or num_updates >= max_update or ( - args.stop_time_hours > 0 - and trainer.cumulative_training_time() / (60 * 60) > args.stop_time_hours + cfg.optimization.stop_time_hours > 0 + and trainer.cumulative_training_time() / (60 * 60) > cfg.optimization.stop_time_hours ) ) # Save checkpoint if do_save or should_stop: logger.info("begin save checkpoint") - checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) + checkpoint_utils.save_checkpoint(cfg.checkpoint, trainer, epoch_itr, valid_losses[0]) return valid_losses, should_stop -def get_training_stats(stats): +def get_training_stats(stats: Dict[str, Any]) -> Dict[str, Any]: stats["wall"] = round(metrics.get_meter("default", "wall").elapsed_time, 0) return stats -def validate(args, trainer, 
task, epoch_itr, subsets): +def validate(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr, subsets: List[str]) -> List[Optional[float]]: """Evaluate the model on the validation set(s) and return the losses.""" - if args.fixed_validation_seed is not None: + if cfg.dataset.fixed_validation_seed is not None: # set fixed seed for every validation - utils.set_torch_seed(args.fixed_validation_seed) + utils.set_torch_seed(cfg.dataset.fixed_validation_seed) trainer.begin_valid_epoch(epoch_itr.epoch) valid_losses = [] for subset in subsets: - logger.info('begin validation on "{}" subset'.format(subset)) + logger.info("begin validation on '{}' subset".format(subset)) # Initialize data iterator itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False) - if getattr(args, "tpu", False): + if cfg.common.tpu: itr = utils.tpu_data_loader(itr) progress = progress_bar.progress_bar( itr, - log_format=args.log_format, - log_interval=args.log_interval, + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, epoch=epoch_itr.epoch, prefix=f"valid on '{subset}' subset", tensorboard_logdir=( - args.tensorboard_logdir if distributed_utils.is_master(args) else None + cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None + ), + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), + wandb_project=( + cfg.common.wandb_project if distributed_utils.is_master(cfg.distributed_training) else None ), - default_log_format=("tqdm" if not args.no_progress_bar else "simple"), ) # create a new root metrics aggregator so validation metrics @@ -329,20 +334,20 @@ def validate(args, trainer, task, epoch_itr, subsets): trainer.valid_step(sample) # log validation stats - stats = get_valid_stats(args, trainer, agg.get_smoothed_values()) + stats = get_valid_stats(cfg, trainer, agg.get_smoothed_values()) progress.print(stats, tag=subset, step=trainer.get_num_updates()) - valid_losses.append(stats[args.best_checkpoint_metric]) + valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric]) return valid_losses -def get_valid_stats(args, trainer, stats): +def get_valid_stats(cfg: DictConfig, trainer: Trainer, stats: Dict[str, Any]) -> Dict[str, Any]: stats["num_updates"] = trainer.get_num_updates() if hasattr(checkpoint_utils.save_checkpoint, "best"): - key = "best_{0}".format(args.best_checkpoint_metric) - best_function = max if args.maximize_best_checkpoint_metric else min + key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric) + best_function = max if cfg.checkpoint.maximize_best_checkpoint_metric else min stats[key] = best_function( - checkpoint_utils.save_checkpoint.best, stats[args.best_checkpoint_metric] + checkpoint_utils.save_checkpoint.best, stats[cfg.checkpoint.best_checkpoint_metric] ) return stats @@ -354,16 +359,19 @@ def print_options_meaning_changes(args): logger.info("--max-tokens is the maximum number of input frames in a batch") -def cli_main(modify_parser=None): +def cli_main(modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None) -> None: parser = options.get_training_parser() args = options.parse_args_and_arch(parser, modify_parser=modify_parser) print_options_meaning_changes(args) + + cfg = convert_namespace_to_omegaconf(args) + if args.profile: with torch.cuda.profiler.profile(): with torch.autograd.profiler.emit_nvtx(): - distributed_utils.call_main(args, main) + distributed_utils.call_main(cfg, main) else: - distributed_utils.call_main(args, main) + 
distributed_utils.call_main(cfg, main) if __name__ == "__main__": diff --git a/espresso/tasks/language_modeling_for_asr.py b/espresso/tasks/language_modeling_for_asr.py index 4af8008b1..541735910 100644 --- a/espresso/tasks/language_modeling_for_asr.py +++ b/espresso/tasks/language_modeling_for_asr.py @@ -6,6 +6,7 @@ import logging import os from dataclasses import dataclass, field +from typing import Optional import torch @@ -22,7 +23,7 @@ @dataclass class LanguageModelingForASRConfig(LanguageModelingConfig): - dict: str = field(default=None, metadata={"help": "path to the dictionary"}) + dict: Optional[str] = field(default=None, metadata={"help": "path to the dictionary"}) @register_task("language_modeling_for_asr", dataclass=LanguageModelingForASRConfig) diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index 5df459456..58ed02b1c 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -8,13 +8,17 @@ import json import logging import os +from dataclasses import dataclass, field +from typing import Optional import torch from fairseq import utils from fairseq.data import BaseWrapperDataset, ConcatDataset +from fairseq.dataclass import FairseqDataclass from fairseq.logging import metrics from fairseq.tasks import FairseqTask, register_task +from omegaconf import II, DictConfig from espresso.data import ( AsrDataset, @@ -27,12 +31,75 @@ logger = logging.getLogger(__name__) +@dataclass +class SpeechRecognitionEspressoConfig(FairseqDataclass): + data: Optional[str] = field( + default=None, metadata={"help": "path to data directory"} + ) + dict: Optional[str] = field(default=None, metadata={"help": "path to the dictionary"}) + non_lang_syms: Optional[str] = field( + default=None, + metadata={ + "help": "path to a file listing non-linguistic symbols, e.g., " + "etc. One entry per line. To be filtered out when calculating WER/CER" + }, + ) + word_dict: Optional[str] = field( + default=None, + metadata={"help": "path to the word dictionary. Only relevant for decoding"}, + ) + wer_output_filter: Optional[str] = field( + default=None, + metadata={"help": "path to wer_output_filter file for WER evaluation"}, + ) + max_source_positions: Optional[int] = field( + default=1024, metadata={"help": "max number of tokens in the source sequence"} + ) + max_target_positions: Optional[int] = field( + default=1024, metadata={"help": "max number of tokens in the target sequence"} + ) + upsample_primary: int = field( + default=1, metadata={"help": "amount to upsample primary dataset"}, + ) + num_batch_buckets: Optional[int] = field( + default=0, + metadata={ + "help": "if >0, then bucket source and target lengths into N " + "buckets and pad accordingly; this is useful on TPUs " + "to minimize the number of compilations" + }, + ) + feat_in_channels: int = field(default=1, metadata={"help": "feature input channels"}) + specaugment_config: Optional[str] = field( + default=None, + metadata={ + "help": "SpecAugment config string. If not None and not empty, " + "then apply SpecAugment. Should be an evaluatable expression of " + "a python dict. See speech_tools.specaug_interpolate.specaug() for " + "all allowed arguments. 
Arguments not appearing in this string "
+            "will take on their default values"
+        },
+    )
+    # TODO common vars below add to parent
+    seed: int = II("common.seed")
+    data_buffer_size: int = II("dataset.data_buffer_size")
+    tpu: bool = II("common.tpu")
+    train_subset: str = II("dataset.train_subset")
+    gen_subset: str = II("dataset.gen_subset")
+    required_seq_len_multiple: int = II("dataset.required_seq_len_multiple")
+
+
 def get_asr_dataset_from_json(
-    data_path, split, tgt_dict,
-    combine, upsample_primary,
-    num_buckets=0, shuffle=True,
+    data_path,
+    split,
+    tgt_dict,
+    combine,
+    upsample_primary,
+    num_buckets=0,
+    shuffle=True,
     pad_to_multiple=1,
-    seed=1, specaugment_config=None,
+    seed=1,
+    specaugment_config=None,
 ):
     """
     Parse data json and create dataset.
@@ -58,7 +125,9 @@ def get_asr_dataset_from_json(
             if k > 0:
                 break
         else:
-            raise FileNotFoundError("Dataset not found: {}".format(data_json_path))
+            raise FileNotFoundError(
+                "Dataset not found: {}".format(data_json_path)
+            )

     with open(data_json_path, "rb") as f:
         loaded_json = json.load(f, object_pairs_hook=OrderedDict)
@@ -97,8 +166,9 @@ def get_asr_dataset_from_json(
         tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
     else:
         for i in range(1, len(src_datasets)):
-            assert feat_dim == src_datasets[i].feat_dim, \
-                "feature dimension does not match across multiple json files"
+            assert (
+                feat_dim == src_datasets[i].feat_dim
+            ), "feature dimension does not match across multiple json files"
         sample_ratios = [1] * len(src_datasets)
         sample_ratios[0] = upsample_primary
         src_dataset = ConcatDataset(src_datasets, sample_ratios)
@@ -109,8 +179,10 @@ def get_asr_dataset_from_json(
         tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None

     return AsrDataset(
-        src_dataset, src_dataset.sizes,
-        tgt_dataset, tgt_dataset_sizes,
+        src_dataset,
+        src_dataset.sizes,
+        tgt_dataset,
+        tgt_dataset_sizes,
         tgt_dict,
         left_pad_source=False,
         left_pad_target=False,
@@ -120,7 +192,7 @@ def get_asr_dataset_from_json(
     )


-@register_task("speech_recognition_espresso")
+@register_task("speech_recognition_espresso", dataclass=SpeechRecognitionEspressoConfig)
 class SpeechRecognitionEspressoTask(FairseqTask):
     """
     Transcribe from speech (source) to token text (target).
@@ -144,40 +216,6 @@ class SpeechRecognitionEspressoTask(FairseqTask):
     :prog:
     """

-    @staticmethod
-    def add_args(parser):
-        """Add task-specific arguments to the parser."""
-        # fmt: off
-        parser.add_argument("data", help="path to data directory")
-        parser.add_argument("--dict", default=None, type=str,
-                            help="path to the dictionary")
-        parser.add_argument("--non-lang-syms", default=None, type=str,
-                            help="path to a file listing non-linguistic symbols, e.g., "
-                            "etc. One entry per line. To be filtered out when calculating WER/CER.")
-        parser.add_argument("--word-dict", default=None, type=str,
-                            help="path to the word dictionary. 
Only relevant for decoding") - parser.add_argument("--wer-output-filter", default=None, type=str, - help="path to wer_output_filter file for WER evaluation") - parser.add_argument("--max-source-positions", default=1024, type=int, metavar="N", - help="max number of frames in the source sequence") - parser.add_argument("--max-target-positions", default=1024, type=int, metavar="N", - help="max number of tokens in the target sequence") - parser.add_argument("--upsample-primary", default=1, type=int, - help="amount to upsample primary dataset") - parser.add_argument("--num-batch-buckets", default=0, type=int, metavar="N", - help="if >0, then bucket source and target lengths into N " - "buckets and pad accordingly; this is useful on TPUs " - "to minimize the number of compilations") - parser.add_argument("--feat-in-channels", default=1, type=int, metavar="N", - help="feature input channels") - parser.add_argument("--specaugment-config", default=None, type=str, metavar="EXPR", - help="SpecAugment config string. If not None and not empty, " - "then apply SpecAugment. Should be an evaluatable expression of " - "a python dict. See speech_tools.specaug_interpolate.specaug() for " - "all allowed arguments. Argments not appearing in this string " - "will take on their default values") - # fmt: off - @classmethod def load_dictionary(cls, filename, non_lang_syms=None): """Load the dictionary from the filename @@ -195,14 +233,12 @@ def build_dictionary( """ raise NotImplementedError - def __init__(self, args, tgt_dict, word_dict=None): - super().__init__(args) + def __init__(self, cfg: DictConfig, tgt_dict, word_dict=None): + super().__init__(cfg) self.tgt_dict = tgt_dict - self.tgt_dict.build_tokenizer(args) - self.tgt_dict.build_bpe(args) self.word_dict = word_dict - self.feat_in_channels = args.feat_in_channels - self.specaugment_config = args.specaugment_config + self.feat_in_channels = cfg.feat_in_channels + self.specaugment_config = cfg.specaugment_config torch.backends.cudnn.deterministic = True # Compansate for the removel of :func:`torch.rand()` from # :func:`fairseq.distributed_utils.distributed_init()` by fairseq, @@ -210,23 +246,23 @@ def __init__(self, args, tgt_dict, word_dict=None): torch.rand(1) @classmethod - def setup_task(cls, args, **kwargs): + def setup_task(cls, cfg: DictConfig, **kwargs): """Setup the task (e.g., load dictionaries). Args: - args (argparse.Namespace): parsed command-line arguments + cfg (omegaconf.DictConfig): parsed command-line arguments """ # load dictionaries - dict_path = os.path.join(args.data, "dict.txt") if args.dict is None else args.dict - tgt_dict = cls.load_dictionary(dict_path, non_lang_syms=args.non_lang_syms) + dict_path = os.path.join(cfg.data, "dict.txt") if cfg.dict is None else cfg.dict + tgt_dict = cls.load_dictionary(dict_path, non_lang_syms=cfg.non_lang_syms) logger.info("dictionary: {} types".format(len(tgt_dict))) - if args.word_dict is not None: - word_dict = cls.load_dictionary(args.word_dict) + if cfg.word_dict is not None: + word_dict = cls.load_dictionary(cfg.word_dict) logger.info("word dictionary: {} types".format(len(word_dict))) - return cls(args, tgt_dict, word_dict) + return cls(cfg, tgt_dict, word_dict) else: - return cls(args, tgt_dict) + return cls(cfg, tgt_dict) def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. 
@@ -234,21 +270,23 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): Args: split (str): name of the split (e.g., train, valid, test) """ - paths = utils.split_paths(self.args.data) + paths = utils.split_paths(self.cfg.data) assert len(paths) > 0 - if split != getattr(self.args, "train_subset", None): + if split != self.cfg.train_subset: # if not training data set, use the first shard for valid and test paths = paths[:1] data_path = paths[(epoch - 1) % len(paths)] self.datasets[split] = get_asr_dataset_from_json( - data_path, split, self.tgt_dict, + data_path, + split, + self.tgt_dict, combine=combine, - upsample_primary=self.args.upsample_primary, - num_buckets=self.args.num_batch_buckets, - shuffle=(split != getattr(self.args, "gen_subset", None)), - pad_to_multiple=self.args.required_seq_len_multiple, - seed=self.args.seed, + upsample_primary=self.cfg.upsample_primary, + num_buckets=self.cfg.num_batch_buckets, + shuffle=(split != self.cfg.gen_subset), + pad_to_multiple=self.cfg.required_seq_len_multiple, + seed=self.cfg.seed, specaugment_config=self.specaugment_config, ) @@ -271,13 +309,17 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): return AsrDataset( - src_tokens, src_lengths, dictionary=self.target_dictionary, constraints=constraints, + src_tokens, + src_lengths, + dictionary=self.target_dictionary, + constraints=constraints, ) - def build_model(self, args): - model = super().build_model(args) + def build_model(self, cfg: DictConfig): + model = super().build_model(cfg) # build the greedy decoder for validation with WER from espresso.tools.simple_greedy_decoder import SimpleGreedyDecoder + self.decoder_for_validation = SimpleGreedyDecoder( [model], self.target_dictionary, for_validation=True, ) @@ -304,13 +346,25 @@ def reduce_metrics(self, logging_outputs, criterion): def max_positions(self): """Return the max sentence length allowed by the task.""" - return (self.args.max_source_positions, self.args.max_target_positions) + return (self.cfg.max_source_positions, self.cfg.max_target_positions) @property def target_dictionary(self): """Return the target :class:`~fairseq.data.AsrDictionary`.""" return self.tgt_dict + def build_tokenizer(self, cfg: DictConfig): + """Build the pre-tokenizer for this task.""" + self.tgt_dict.build_tokenizer(cfg) + # the instance is built within self.tgt_dict + return self.tgt_dict.tokenizer + + def build_bpe(self, cfg: DictConfig): + """Build the tokenizer for this task.""" + self.tgt_dict.build_bpe(cfg) + # the instance is built within self.tgt_dict + return self.tgt_dict.bpe + @property def word_dictionary(self): """Return the target :class:`~fairseq.data.AsrDictionary`.""" @@ -319,7 +373,7 @@ def word_dictionary(self): def _inference_with_wer(self, decoder, sample, model): from espresso.tools import wer - scorer = wer.Scorer(self.target_dictionary, wer_output_filter=self.args.wer_output_filter) + scorer = wer.Scorer(self.target_dictionary, wer_output_filter=self.cfg.wer_output_filter) tokens, lprobs, _ = decoder.decode([model], sample) pred = tokens[:, 1:].data.cpu() # bsz x len target = sample["target"] diff --git a/espresso/tasks/speech_recognition_hybrid.py b/espresso/tasks/speech_recognition_hybrid.py index 6eedeaa84..a10b07ea5 100644 --- a/espresso/tasks/speech_recognition_hybrid.py +++ b/espresso/tasks/speech_recognition_hybrid.py @@ -8,13 +8,17 @@ import json import logging import os +from dataclasses import dataclass, field 
+from typing import Optional

 import torch

 from fairseq import utils
 from fairseq.data import BaseWrapperDataset, ConcatDataset
-
+from fairseq.dataclass import FairseqDataclass
+from fairseq.dataclass.configs import GenerationConfig
 from fairseq.tasks import FairseqTask, register_task
+from omegaconf import II, DictConfig

 from espresso.data import (
     AliScpCachedDataset,
@@ -35,14 +39,129 @@
 logger = logging.getLogger(__name__)


+@dataclass
+class SpeechRecognitionHybridConfig(FairseqDataclass):
+    data: Optional[str] = field(
+        default=None, metadata={"help": "path to data directory"}
+    )
+    dict: Optional[str] = field(default=None, metadata={"help": "path to the dictionary"})
+    non_lang_syms: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "path to a file listing non-linguistic symbols, e.g., "
+            "etc. One entry per line. To be filtered out when calculating WER/CER"
+        },
+    )
+    wer_output_filter: Optional[str] = field(
+        default=None,
+        metadata={"help": "path to wer_output_filter file for WER evaluation"},
+    )
+    max_source_positions: Optional[int] = field(
+        default=1024, metadata={"help": "max number of tokens in the source sequence"}
+    )
+    max_target_positions: Optional[int] = field(
+        default=1024, metadata={"help": "max number of tokens in the target sequence"}
+    )
+    upsample_primary: int = field(
+        default=1, metadata={"help": "amount to upsample primary dataset"},
+    )
+    num_batch_buckets: Optional[int] = field(
+        default=0,
+        metadata={
+            "help": "if >0, then bucket source and target lengths into N "
+            "buckets and pad accordingly; this is useful on TPUs "
+            "to minimize the number of compilations"
+        },
+    )
+    feat_in_channels: int = field(default=1, metadata={"help": "feature input channels"})
+    specaugment_config: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "SpecAugment config string. If not None and not empty, "
+            "then apply SpecAugment. Should be an evaluatable expression of "
+            "a python dict. See speech_tools.specaug_interpolate.specaug() for "
+            "all allowed arguments. Arguments not appearing in this string "
+            "will take on their default values"
+        },
+    )
+    num_targets: int = field(
+        default=3000,
+        metadata={"help": "number of targets for training (e.g., num pdf-ids)"},
+    )
+    initial_state_prior_file: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "path to the file containing initial state prior. Only relevant "
+            "with cross-entropy training"
+        },
+    )
+    state_prior_update_interval: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "state prior estimate will be updated every this number of updates "
+            "during training. If None, then use the initial value estimated from the "
+            "alignments. Only relevant with cross-entropy training"
+        },
+    )
+    state_prior_update_smoothing: Optional[float] = field(
+        default=0.1,
+        metadata={
+            "help": "smoothing factor while updating state prior estimate. Only "
+            "relevant with cross-entropy training"
+        },
+    )
+    chunk_width: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "chunk width for train/test data. Only relevant with chunk-wise "
+            "training (including both cross-entropy and Lattice-free MMI). 
" + "Do utterance-wise training/test if not specified" + }, + ) + chunk_left_context: Optional[int] = field( + default=0, + metadata={"help": "number of frames appended to the left of a chunk"}, + ) + chunk_right_context: Optional[int] = field( + default=0, + metadata={"help": "number of frames appended to the right of a chunk"}, + ) + label_delay: Optional[int] = field( + default=0, + metadata={ + "help": "offet of alignments as prediction labels. Maybe useful " + "in archs such as asymmetric convolution, unidirectional LSTM, etc. " + "It can be negative. Only relevant with chunk-wise cross-entropy training" + }, + ) + # TODO common vars below add to parent + seed: int = II("common.seed") + data_buffer_size: int = II("dataset.data_buffer_size") + tpu: bool = II("common.tpu") + train_subset: str = II("dataset.train_subset") + valid_subset: str = II("dataset.valid_subset") + gen_subset: str = II("dataset.gen_subset") + required_seq_len_multiple: int = II("dataset.required_seq_len_multiple") + criterion_name: str = II("criterion._name") + max_epoch: int = II("optimization.max_epoch") # to determine whether in trainig stage + + def get_asr_dataset_from_json( - data_path, split, dictionary, - combine, upsample_primary, - num_buckets=0, shuffle=True, + data_path, + split, + dictionary, + combine, + upsample_primary, + num_buckets=0, + shuffle=True, pad_to_multiple=1, lf_mmi=True, - seed=1, specaugment_config=None, - chunk_width=None, chunk_left_context=None, chunk_right_context=None, label_delay=0, + seed=1, + specaugment_config=None, + chunk_width=None, + chunk_left_context=None, + chunk_right_context=None, + label_delay=0, ): """ Parse data json and create dataset. @@ -72,7 +191,9 @@ def get_asr_dataset_from_json( if k > 0: break else: - raise FileNotFoundError("Dataset not found: {}".format(data_json_path)) + raise FileNotFoundError( + "Dataset not found: {}".format(data_json_path) + ) with open(data_json_path, "rb") as f: loaded_json = json.load(f, object_pairs_hook=OrderedDict) @@ -103,9 +224,12 @@ def get_asr_dataset_from_json( else: # cross-entropy if len(alignments) > 0: assert len(utt_ids) == len(alignments) - tgt_datasets.append(AliScpCachedDataset( - utt_ids, alignments, utt2num_frames=utt2num_frames, ordered_prefetch=True - )) + tgt_datasets.append( + AliScpCachedDataset( + utt_ids, alignments, utt2num_frames=utt2num_frames, + ordered_prefetch=True, + ) + ) if len(text) > 0: assert len(utt_ids) == len(text) @@ -127,8 +251,9 @@ def get_asr_dataset_from_json( text_dataset = text_datasets[0] if len(text_datasets) > 0 else None else: for i in range(1, len(src_datasets)): - assert feat_dim == src_datasets[i].feat_dim, \ - "feature dimension does not match across multiple json files" + assert ( + feat_dim == src_datasets[i].feat_dim + ), "feature dimension does not match across multiple json files" sample_ratios = [1] * len(src_datasets) sample_ratios[0] = upsample_primary src_dataset = ConcatDataset(src_datasets, sample_ratios) @@ -144,8 +269,10 @@ def get_asr_dataset_from_json( tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None if lf_mmi: return AsrChainDataset( - src_dataset, src_dataset.sizes, - tgt_dataset, tgt_dataset_sizes, + src_dataset, + src_dataset.sizes, + tgt_dataset, + tgt_dataset_sizes, text=text_dataset, num_buckets=num_buckets, shuffle=shuffle, @@ -153,19 +280,24 @@ def get_asr_dataset_from_json( ) else: return AsrXentDataset( - src_dataset, src_dataset.sizes, - tgt_dataset, tgt_dataset_sizes, + src_dataset, + src_dataset.sizes, + tgt_dataset, + 
tgt_dataset_sizes, text=text_dataset, num_buckets=num_buckets, shuffle=shuffle, pad_to_multiple=pad_to_multiple, - seed=seed, chunk_width=chunk_width, - chunk_left_context=chunk_left_context, chunk_right_context=chunk_right_context, - label_delay=label_delay, random_chunking=(split == "train" and chunk_width is not None), + seed=seed, + chunk_width=chunk_width, + chunk_left_context=chunk_left_context, + chunk_right_context=chunk_right_context, + label_delay=label_delay, + random_chunking=(split == "train" and chunk_width is not None), ) -@register_task("speech_recognition_hybrid") +@register_task("speech_recognition_hybrid", dataclass=SpeechRecognitionHybridConfig) class SpeechRecognitionHybridTask(FairseqTask): """ Hybrid speech recognition with lattice-free MMI or cross-entropy loss. @@ -192,64 +324,6 @@ class SpeechRecognitionHybridTask(FairseqTask): :prog: """ - @staticmethod - def add_args(parser): - """Add task-specific arguments to the parser.""" - # fmt: off - parser.add_argument("data", help="path to data directory") - parser.add_argument("--dict", default=None, type=str, - help="path to the dictionary") - parser.add_argument("--non-lang-syms", default=None, type=str, - help="path to a file listing non-linguistic symbols, e.g., " - "etc. One entry per line. To be filtered out when calculating WER/CER.") - parser.add_argument("--wer-output-filter", default=None, type=str, - help="path to wer_output_filter file for WER evaluation") - parser.add_argument("--max-source-positions", default=1024, type=int, metavar="N", - help="max number of frames in the source sequence") - parser.add_argument("--max-target-positions", default=1024, type=int, metavar="N", - help="max number of tokens in the target sequence") - parser.add_argument("--upsample-primary", default=1, type=int, - help="amount to upsample primary dataset") - parser.add_argument("--num-batch-buckets", default=0, type=int, metavar="N", - help="if >0, then bucket source and target lengths into N " - "buckets and pad accordingly; this is useful on TPUs " - "to minimize the number of compilations") - parser.add_argument("--feat-in-channels", default=1, type=int, metavar="N", - help="feature input channels") - parser.add_argument("--specaugment-config", default=None, type=str, metavar="EXPR", - help="SpecAugment config string. If not None and not empty, " - "then apply SpecAugment. Should be an evaluatable expression of " - "a python dict. See speech_tools.specaug_interpolate.specaug() for " - "all allowed arguments. Argments not appearing in this string " - "will take on their default values") - - parser.add_argument("--num-targets", type=int, metavar="N", - help="number of targets for training (e.g., num pdf-ids)") - parser.add_argument("--initial-state-prior-file", default=None, type=str, metavar="FILE", - help="path to the file containing initial state prior. Only relevant " - "with cross-entropy training") - parser.add_argument("--state-prior-update-interval", default=None, type=int, metavar="N", - help="state prior estimate will be updated every this " - "number of updates during training. If None, then use " - "the initial value estimated from the alignments. Only relevant with " - "cross-entropy training") - parser.add_argument("--state-prior-update-smoothing", default=0.1, type=float, metavar="D", - help="smoothing factor while updating state prior estimate. Only " - "relevant with cross-entropy training") - parser.add_argument("--chunk-width", default=None, type=int, metavar="D", - help="chunk width for train/test data. 
Only relevant with chunk-wise " - "training (including both cross-entropy and Lattice-free MMI). " - "Do utterance-wise training/test if not specified") - parser.add_argument("--chunk-left-context", default=0, type=int, metavar="D", - help="number of frames appended to the left of a chunk") - parser.add_argument("--chunk-right-context", default=0, type=int, metavar="D", - help="number of frames appended to the right of a chunk") - parser.add_argument("--label-delay", default=0, type=int, metavar="D", - help="offet of alignments as prediction labels. Maybe useful " - "in archs such as asymmetric convolution, unidirectional LSTM, etc. " - "It can be negative. Only relevant with chunk-wise cross-entropy training") - # fmt: off - @classmethod def load_dictionary(cls, filename, non_lang_syms=None): """Load the dictionary from the filename @@ -265,51 +339,55 @@ def build_dictionary(cls, filenames, workers=1, threshold=-1, nwords=-1, padding """ raise NotImplementedError - def __init__(self, args, dictionary): - super().__init__(args) + def __init__(self, cfg: DictConfig, dictionary): + super().__init__(cfg) self.dictionary = dictionary - self.feat_in_channels = args.feat_in_channels - self.specaugment_config = args.specaugment_config - self.num_targets = args.num_targets - self.training_stage = hasattr(args, "valid_subset") + self.feat_in_channels = cfg.feat_in_channels + self.specaugment_config = cfg.specaugment_config + self.num_targets = cfg.num_targets + self.training_stage = (cfg.max_epoch > 0) # a hack # the following attributes are related to state_prior estimate self.initial_state_prior = None - if args.initial_state_prior_file is not None: # only relevant for Xent training, used in models - self.initial_state_prior = kaldi_io.read_vec_flt(args.initial_state_prior_file) + if cfg.initial_state_prior_file is not None: # only relevant for Xent training, used in models + self.initial_state_prior = kaldi_io.read_vec_flt(cfg.initial_state_prior_file) self.initial_state_prior = torch.from_numpy(self.initial_state_prior) - assert self.initial_state_prior.size(0) == self.num_targets, \ - "length of initial_state_prior ({}) != num_targets ({})".format( - self.initial_state_prior.size(0), self.num_targets - ) - self.state_prior_update_interval = args.state_prior_update_interval + assert ( + self.initial_state_prior.size(0) == self.num_targets + ), "length of initial_state_prior ({}) != num_targets ({})".format( + self.initial_state_prior.size(0), self.num_targets + ) + self.state_prior_update_interval = cfg.state_prior_update_interval if self.state_prior_update_interval is None and self.initial_state_prior is not None: logger.info("state prior will not be updated during training") - self.state_prior_update_smoothing = args.state_prior_update_smoothing + self.state_prior_update_smoothing = cfg.state_prior_update_smoothing self.averaged_state_post = None # state poterior will be saved here before commited as new state prior # the following 4 options are for chunk-wise training/test (including Xent and LF-MMI) - self.chunk_width = args.chunk_width - self.chunk_left_context = args.chunk_left_context - self.chunk_right_context = args.chunk_right_context - self.label_delay = args.label_delay # only for chunk-wise Xent training + self.chunk_width = cfg.chunk_width + self.chunk_left_context = cfg.chunk_left_context + self.chunk_right_context = cfg.chunk_right_context + self.label_delay = cfg.label_delay # only for chunk-wise Xent training torch.backends.cudnn.deterministic = True @classmethod - def 
setup_task(cls, args, **kwargs): + def setup_task(cls, cfg: DictConfig, **kwargs): """Setup the task (e.g., load dictionaries). Args: - args (argparse.Namespace): parsed command-line arguments + cfg (omegaconf.DictConfig): parsed command-line arguments """ # load dictionaries - dict_path = args.dict - dictionary = cls.load_dictionary(dict_path, non_lang_syms=args.non_lang_syms) if \ - dict_path is not None else None + dict_path = cfg.dict + dictionary = ( + cls.load_dictionary(dict_path, non_lang_syms=cfg.non_lang_syms) + if dict_path is not None + else None + ) if dictionary is not None: logger.info("dictionary: {} types".format(len(dictionary))) - return cls(args, dictionary) + return cls(cfg, dictionary) def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. @@ -317,24 +395,30 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): Args: split (str): name of the split (e.g., train, valid, test) """ - paths = utils.split_paths(self.args.data) + paths = utils.split_paths(self.cfg.data) assert len(paths) > 0 - if split != getattr(self.args, "train_subset", None): + if split != self.cfg.train_subset: # if not training data set, use the first shard for valid and test paths = paths[:1] data_path = paths[(epoch - 1) % len(paths)] self.datasets[split] = get_asr_dataset_from_json( - data_path, split, self.dictionary, + data_path, + split, + self.dictionary, combine=combine, - upsample_primary=self.args.upsample_primary, - num_buckets=self.args.num_batch_buckets, - shuffle=(split != getattr(self.args, "gen_subset", None)), - pad_to_multiple=self.args.required_seq_len_multiple, - lf_mmi=(self.args.criterion == "lattice_free_mmi"), - seed=self.args.seed, specaugment_config=self.specaugment_config, - chunk_width=None if self.training_stage and split in self.args.valid_subset.split(",") else self.chunk_width, - chunk_left_context=self.chunk_left_context, chunk_right_context=self.chunk_right_context, + upsample_primary=self.cfg.upsample_primary, + num_buckets=self.cfg.num_batch_buckets, + shuffle=(split != self.cfg.gen_subset), + pad_to_multiple=self.cfg.required_seq_len_multiple, + lf_mmi=(self.cfg.criterion_name == "lattice_free_mmi"), + seed=self.cfg.seed, + specaugment_config=self.specaugment_config, + chunk_width=None if self.training_stage + and split in self.cfg.valid_subset.split(",") + else self.chunk_width, + chunk_left_context=self.chunk_left_context, + chunk_right_context=self.chunk_right_context, label_delay=self.label_delay, ) @@ -346,14 +430,15 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): else: self.feat_dim = src_dataset.feat_dim - def build_generator(self, models, args): - if args.score_reference: - args.score_reference = False + def build_generator(self, models, cfg: GenerationConfig): + if cfg.score_reference: + cfg.score_reference = False logger.warning( "--score-reference is not applicable to speech recognition, ignoring it." 
) from espresso.tools.generate_log_probs_for_decoding import GenerateLogProbsForDecoding - apply_log_softmax = getattr(args, "apply_log_softmax", False) + + apply_log_softmax = getattr(cfg, "apply_log_softmax", False) return GenerateLogProbsForDecoding(models, apply_log_softmax=apply_log_softmax) def build_dataset_for_inference(self, src_tokens, src_lengths): @@ -387,7 +472,7 @@ def update_state_prior(self, model): def max_positions(self): """Return the max sentence length allowed by the task.""" - return (self.args.max_source_positions, self.args.max_target_positions) + return (self.cfg.max_source_positions, self.cfg.max_target_positions) @property def target_dictionary(self): diff --git a/espresso/tools/compute_wer.py b/espresso/tools/compute_wer.py index 8555e0995..7a56fed33 100755 --- a/espresso/tools/compute_wer.py +++ b/espresso/tools/compute_wer.py @@ -14,27 +14,27 @@ logging.basicConfig( - format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, stream=sys.stderr, ) -logger = logging.getLogger('espresso.tools.compute_wer') +logger = logging.getLogger("espresso.tools.compute_wer") def get_parser(): parser = argparse.ArgumentParser( - description='Compute WER from text') + description="Compute WER from text") # fmt: off - parser.add_argument('--non-lang-syms', default=None, type=str, - help='path to a file listing non-linguistic symbols, ' - 'e.g., etc. One entry per line.') - parser.add_argument('--wer-output-filter', default=None, type=str, - help='path to wer_output_filter file for WER evaluation') - parser.add_argument('ref_text', type=str, - help='path to the reference text file') - parser.add_argument('hyp_text', type=str, - help='path to the hypothesis text file') + parser.add_argument("--non-lang-syms", default=None, type=str, + help="path to a file listing non-linguistic symbols, " + "e.g., etc. One entry per line.") + parser.add_argument("--wer-output-filter", default=None, type=str, + help="path to wer_output_filter file for WER evaluation") + parser.add_argument("ref_text", type=str, + help="path to the reference text file") + parser.add_argument("hyp_text", type=str, + help="path to the hypothesis text file") # fmt: on @@ -44,36 +44,36 @@ def get_parser(): def main(args): non_lang_syms = [] if args.non_lang_syms is not None: - with open(args.non_lang_syms, 'r', encoding='utf-8') as f: + with open(args.non_lang_syms, "r", encoding="utf-8") as f: non_lang_syms = [x.rstrip() for x in f.readlines()] word_filters = [] if args.wer_output_filter is not None: - with open(args.wer_output_filter, 'r', encoding='utf-8') as f: + with open(args.wer_output_filter, "r", encoding="utf-8") as f: for line in f: line = line.strip() - if line.startswith('#!') or line == '': + if line.startswith("#!") or line == "": continue - elif line.startswith('s/'): - m = re.match(r's/(\S+)/(\w*)/g', line) + elif line.startswith("s/"): + m = re.match(r"s/(\S+)/(\w*)/g", line) assert m is not None word_filters.append([m.group(1), m.group(2)]) - elif line.startswith('s:'): - m = re.match(r's:(\S+):(\w*):g', line) + elif line.startswith("s:"): + m = re.match(r"s:(\S+):(\w*):g", line) assert m is not None word_filters.append([m.group(1), m.group(2)]) else: - logger.warning('Unsupported pattern: "{}". Ignoring it.'.format(line)) + logger.warning("Unsupported pattern: '{}'. 
Ignoring it.".format(line)) refs = {} - with open(args.ref_text, 'r', encoding='utf-8') as f: + with open(args.ref_text, "r", encoding="utf-8") as f: for line in f: utt_id, text = line.strip().split(None, 1) assert utt_id not in refs, utt_id refs[utt_id] = text wer_counter = Counter() - with open(args.hyp_text, 'r', encoding='utf-8') as f: + with open(args.hyp_text, "r", encoding="utf-8") as f: for line in f: utt_id, text = line.strip().split(None, 1) assert utt_id in refs, utt_id @@ -91,19 +91,19 @@ def main(args): _, _, counter = edit_distance(ref_list, hyp_list) wer_counter += counter - assert wer_counter['words'] > 0 + assert wer_counter["words"] > 0 wer = float( - wer_counter['sub'] + wer_counter['ins'] + wer_counter['del'] - ) / wer_counter['words'] * 100 - sub = float(wer_counter['sub']) / wer_counter['words'] * 100 - ins = float(wer_counter['ins']) / wer_counter['words'] * 100 - dlt = float(wer_counter['del']) / wer_counter['words'] * 100 + wer_counter["sub"] + wer_counter["ins"] + wer_counter["del"] + ) / wer_counter["words"] * 100 + sub = float(wer_counter["sub"]) / wer_counter["words"] * 100 + ins = float(wer_counter["ins"]) / wer_counter["words"] * 100 + dlt = float(wer_counter["del"]) / wer_counter["words"] * 100 - print('WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%, #words={:d}'.format( - wer, sub, ins, dlt, wer_counter['words'])) + print("WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%, #words={:d}".format( + wer, sub, ins, dlt, wer_counter["words"])) -if __name__ == '__main__': +if __name__ == "__main__": parser = get_parser() args = parser.parse_args() main(args) diff --git a/espresso/tools/estimate_initial_state_prior_from_alignments.py b/espresso/tools/estimate_initial_state_prior_from_alignments.py index a1d111106..9f48da1ef 100755 --- a/espresso/tools/estimate_initial_state_prior_from_alignments.py +++ b/espresso/tools/estimate_initial_state_prior_from_alignments.py @@ -13,7 +13,7 @@ try: import kaldi_io except ImportError: - raise ImportError('Please install kaldi_io with: pip install kaldi_io') + raise ImportError("Please install kaldi_io with: pip install kaldi_io") logging.basicConfig( diff --git a/espresso/tools/lexical_prefix_tree.py b/espresso/tools/lexical_prefix_tree.py index 79a0e8f38..d97281404 100644 --- a/espresso/tools/lexical_prefix_tree.py +++ b/espresso/tools/lexical_prefix_tree.py @@ -24,8 +24,8 @@ def lexical_prefix_tree( Return: root (Node): the root of the prefix tree, where each node has the fields: - ('children': Dict[int,Node], 'word_idx': int, 'word_set': Tuple[int]). - 'children' is subword_idx -> node, and 'word_set' is (first-1, last), + ("children": Dict[int,Node], "word_idx": int, "word_set": Tuple[int]). + "children" is subword_idx -> node, and "word_set" is (first-1, last), where [first, last] is the range of the word indexes (inclusive) in the word dictionary who share the same prefix at that node. We assume words in the word dictionary are in lexical order. 
@@ -43,8 +43,11 @@ def __init__(self, children={}, word_idx=-1, word_set=None): for widx in range(len(word_dict)): if widx not in special_symbols: # skip , , # tokenize a word into a list of subwords - subwords = subword_tokenizer(word_dict[widx]) \ - if subword_tokenizer is not None else list(word_dict[widx]) + subwords = ( + subword_tokenizer(word_dict[widx]) + if subword_tokenizer is not None + else list(word_dict[widx]) + ) if any(subword_dict.index(s) == subword_dict.unk() for s in subwords): # skip words containing any unknown subwords continue diff --git a/espresso/tools/text2token.py b/espresso/tools/text2token.py index 455d8ebe1..2f3169365 100755 --- a/espresso/tools/text2token.py +++ b/espresso/tools/text2token.py @@ -12,20 +12,27 @@ def get_parser(): parser = argparse.ArgumentParser( - description='Convert transcripts into tokens and write them to stdout') + description="Convert transcripts into tokens and write them to stdout" + ) # fmt: off - parser.add_argument('--skip-ncols', default=0, type=int, - help='skip first n columns') - parser.add_argument('--space', default='', type=str, - help='space symbol') - parser.add_argument('--ends-with-space', default=True, type=bool, - help='Whether to append to the end of each ' - 'tokenized sentence.') - parser.add_argument('--non-lang-syms', default=None, type=str, - help='path to a file listing non-linguistic symbols, ' - 'e.g., etc. One entry per line.') - parser.add_argument('text', type=str, nargs='?', - help='input text') + parser.add_argument( + "--skip-ncols", default=0, type=int, help="skip first n columns" + ) + parser.add_argument( + "--space", default="", type=str, help="space symbol" + ) + parser.add_argument( + "--ends-with-space", default=True, type=bool, + help="whether to append to the end of each tokenized sentence." + ) + parser.add_argument( + "--non-lang-syms", default=None, type=str, + help="path to a file listing non-linguistic symbols, " + "e.g., etc. One entry per line." 
+ ) + parser.add_argument( + "text", type=str, nargs="?", help="input text" + ) # fmt: on return parser @@ -34,29 +41,29 @@ def get_parser(): def main(args): nls = None if args.non_lang_syms is not None: - with open(args.non_lang_syms, 'r', encoding='utf-8') as f: + with open(args.non_lang_syms, "r", encoding="utf-8") as f: nls = [x.rstrip() for x in f.readlines()] - with (open(args.text, 'r', encoding='utf-8') if args.text else sys.stdin) as f: + with (open(args.text, "r", encoding="utf-8") if args.text else sys.stdin) as f: for line in f: entry = line.rstrip().split() tokenized = tokenize( - ' '.join(entry[args.skip_ncols:]), + " ".join(entry[args.skip_ncols:]), space=args.space, non_lang_syms=nls, ) if args.skip_ncols > 0: if args.ends_with_space: - print(' '.join(entry[:args.skip_ncols]) + ' ' + tokenized + ' ' + args.space) + print(" ".join(entry[: args.skip_ncols]) + " " + tokenized + " " + args.space) else: - print(' '.join(entry[:args.skip_ncols]) + ' ' + tokenized) + print(" ".join(entry[: args.skip_ncols]) + " " + tokenized) else: if args.ends_with_space: - print(tokenized + ' ' + args.space) + print(tokenized + " " + args.space) else: print(tokenized) -if __name__ == '__main__': +if __name__ == "__main__": parser = get_parser() args = parser.parse_args() main(args) diff --git a/espresso/tools/text2vocabulary.py b/espresso/tools/text2vocabulary.py index 047d43957..2a054e4f9 100755 --- a/espresso/tools/text2vocabulary.py +++ b/espresso/tools/text2vocabulary.py @@ -12,50 +12,56 @@ logging.basicConfig( - format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, stream=sys.stderr, ) -logger = logging.getLogger('espresso.tools.text2vocabulary') +logger = logging.getLogger("espresso.tools.text2vocabulary") def get_parser(): - parser = argparse.ArgumentParser( - description='Create a vocabulary from text files') + parser = argparse.ArgumentParser(description="Create a vocabulary from text files") # fmt: off - parser.add_argument('--skip-ncols', default=0, type=int, - help='skip first n columns') - parser.add_argument('--cutoff', default=0, type=int, - help='cut-off frequency') - parser.add_argument('--vocabsize', default=20000, type=int, - help='vocabulary size') - parser.add_argument('--exclude', type=str, default=None, - help='space separated, list of excluding words, ' - 'e.g., etc.') - parser.add_argument('--vocab', type=str, default=None, - help='path to the vocabulary file. If not None, calculate' - 'OOV stats with the provided vocabulary and output the ' - 'same vocabulary with word counts') - parser.add_argument('--valid-text', type=str, default=None, - help='path to the validation text file') - parser.add_argument('--test-text', type=str, default=None, - help='colon separated paths to the test text file(s)') - parser.add_argument('text_files', nargs='*', - help='input text files') + parser.add_argument( + "--skip-ncols", default=0, type=int, help="skip first n columns" + ) + parser.add_argument( + "--cutoff", default=0, type=int, help="cut-off frequency" + ) + parser.add_argument( + "--vocabsize", default=20000, type=int, help="vocabulary size" + ) + parser.add_argument( + "--exclude", type=str, default=None, + help="space separated, list of excluding words, e.g., etc." + ) + parser.add_argument( + "--vocab", type=str, default=None, + help="path to the vocabulary file. 
If not None, calculate OOV stats with " + "the provided vocabulary and output the same vocabulary with word counts" + ) + parser.add_argument( + "--valid-text", type=str, default=None, help="path to the validation text file" + ) + parser.add_argument( + "--test-text", type=str, default=None, + help="colon separated paths to the test text file(s)" + ) + parser.add_argument("text_files", nargs="*", help="input text files") # fmt: on return parser def main(args): - exclude = args.exclude.split(' ') if args.exclude is not None else [] + exclude = args.exclude.split(" ") if args.exclude is not None else [] if len(args.text_files) == 0: - args.text_files.append('-') + args.text_files.append("-") counter = Counter() for fn in args.text_files: - with (open(fn, 'r', encoding='utf-8') if fn != '-' else sys.stdin) as f: + with (open(fn, "r", encoding="utf-8") if fn != "-" else sys.stdin) as f: for line in f: tokens = line.rstrip().split()[args.skip_ncols:] tokens = [tok for tok in tokens if tok not in exclude] @@ -75,8 +81,8 @@ def main(args): most_common = most_common[:cutoff_point] vocab_set = set(list(zip(*most_common))[0]) else: - logger.info('using the provided vocabulary:') - with open(args.vocab, 'r', encoding='utf-8') as f: + logger.info("using the provided vocabulary:") + with open(args.vocab, "r", encoding="utf-8") as f: vocab_set = set([line.rstrip().split()[0] for line in f]) most_common = [] for word in vocab_set: @@ -85,46 +91,46 @@ def main(args): # words in vocabulary are lexically sorted for w, c in sorted(most_common, key=lambda x: x[0]): - print('{} {:d}'.format(w, c)) + print("{} {:d}".format(w, c)) - oov_rate = 1. - float(invocab_count) / total_count - logger.info('training set:') - logger.info(' total #tokens={:d}'.format(total_count)) - logger.info(' OOV rate={:.2f}%'.format(oov_rate * 100)) + oov_rate = 1.0 - float(invocab_count) / total_count + logger.info("training set:") + logger.info(" total #tokens={:d}".format(total_count)) + logger.info(" OOV rate={:.2f}%".format(oov_rate * 100)) if args.vocab is None: - logger.info(' cutoff frequency={:d}'.format(cutoff_freq)) + logger.info(" cutoff frequency={:d}".format(cutoff_freq)) if args.valid_text is not None: total_count = 0 invocab_count = 0 - with open(args.valid_text, 'r', encoding='utf-8') as f: + with open(args.valid_text, "r", encoding="utf-8") as f: for line in f: tokens = line.rstrip().split()[args.skip_ncols:] tokens = [tok for tok in tokens if tok not in exclude] total_count += len(tokens) invocab_count += len([tok for tok in tokens if tok in vocab_set]) - oov_rate = 1. - float(invocab_count) / total_count - logger.info('validation set:') - logger.info(' total #tokens={:d}'.format(total_count)) - logger.info(' OOV rate={:.2f}%'.format(oov_rate * 100)) + oov_rate = 1.0 - float(invocab_count) / total_count + logger.info("validation set:") + logger.info(" total #tokens={:d}".format(total_count)) + logger.info(" OOV rate={:.2f}%".format(oov_rate * 100)) if args.test_text is not None: for k, path in enumerate(args.test_text.split(os.pathsep)): total_count = 0 invocab_count = 0 - with open(path, 'r', encoding='utf-8') as f: + with open(path, "r", encoding="utf-8") as f: for line in f: tokens = line.rstrip().split()[args.skip_ncols:] tokens = [tok for tok in tokens if tok not in exclude] total_count += len(tokens) invocab_count += len([tok for tok in tokens if tok in vocab_set]) - oov_rate = 1. 
- float(invocab_count) / total_count - logger.info('test set{}:'.format(k) if k > 0 else 'test set:') - logger.info(' total #tokens={:d}'.format(total_count)) - logger.info(' OOV rate={:.2f}%'.format(oov_rate * 100)) + oov_rate = 1.0 - float(invocab_count) / total_count + logger.info("test set{}:".format(k) if k > 0 else "test set:") + logger.info(" total #tokens={:d}".format(total_count)) + logger.info(" OOV rate={:.2f}%".format(oov_rate * 100)) -if __name__ == '__main__': +if __name__ == "__main__": parser = get_parser() args = parser.parse_args() main(args) diff --git a/espresso/tools/utils.py b/espresso/tools/utils.py index 19fe51ab5..0ee0b728d 100644 --- a/espresso/tools/utils.py +++ b/espresso/tools/utils.py @@ -13,15 +13,15 @@ from fairseq import utils -def tokenize(sent, space='', non_lang_syms=None): +def tokenize(sent, space="", non_lang_syms=None): assert isinstance(sent, str) - sent = ' '.join(sent.strip().split()) + sent = " ".join(sent.strip().split()) match_pos = [] if non_lang_syms is not None: assert isinstance(non_lang_syms, list) if len(non_lang_syms) > 0: - prog = re.compile('|'.join(map(re.escape, non_lang_syms))) + prog = re.compile("|".join(map(re.escape, non_lang_syms))) matches = prog.finditer(sent) for match in matches: match_pos.append([match.start(), match.end()]) @@ -34,8 +34,8 @@ def tokenize(sent, space='', non_lang_syms=None): i = end_pos tokens.extend([token for token in sent[i:]]) - tokens = [space if token == ' ' else token for token in tokens] - return ' '.join(tokens) + tokens = [space if token == " " else token for token in tokens] + return " ".join(tokens) def collate_frames(values, pad_value=0.0, left_pad=False, pad_to_length=None, pad_to_multiple=1): @@ -119,7 +119,7 @@ def plot_attention(attention, hypo_sent, utt_id, save_dir): """ try: import matplotlib as mpl - mpl.use('Agg') + mpl.use("Agg") import matplotlib.pyplot as plt except ImportError: raise ImportError( @@ -131,8 +131,8 @@ def plot_attention(attention, hypo_sent, utt_id, save_dir): attn = attention.data.numpy() plt.matshow(attn) plt.title(hypo_sent, fontsize=8) - filename = os.path.join(save_dir, utt_id + '.pdf') - plt.savefig(filename, bbox_inches='tight') + filename = os.path.join(save_dir, utt_id + ".pdf") + plt.savefig(filename, bbox_inches="tight") plt.close() @@ -149,8 +149,8 @@ def edit_distance(ref, hyp): dist: edit distance matrix of size len(ref) x len(hyp) steps: list of edit steps counter: object of collections.Counter containing counts of - reference words ('words'), number of correct words ('corr'), - substitutions ('sub'), insertions ('ins'), deletions ('del'). + reference words ("words"), number of correct words ("corr"), + substitutions ("sub"), insertions ("ins"), deletions ("del"). 
""" assert isinstance(ref, list) and isinstance(hyp, list) @@ -182,23 +182,23 @@ def edit_distance(ref, hyp): i >= 1 and j >= 1 and dist[i][j] == dist[i - 1][j - 1] and ref[i - 1] == hyp[j - 1] ): - steps.append('corr') + steps.append("corr") i, j = i - 1, j - 1 elif i >= 1 and j >= 1 and dist[i][j] == dist[i - 1][j - 1] + 1: assert ref[i - 1] != hyp[j - 1] - steps.append('sub') + steps.append("sub") i, j = i - 1, j - 1 elif j >= 1 and dist[i][j] == dist[i][j - 1] + 1: - steps.append('ins') + steps.append("ins") j = j - 1 else: assert i >= 1 and dist[i][j] == dist[i - 1][j] + 1 - steps.append('del') + steps.append("del") i = i - 1 steps = steps[::-1] counter = Counter( - {'words': len(ref), 'corr': 0, 'sub': 0, 'ins': 0, 'del': 0} + {"words": len(ref), "corr": 0, "sub": 0, "ins": 0, "del": 0} ) counter.update(steps) @@ -212,7 +212,7 @@ def aligned_print(ref, hyp, steps): Args: ref: list of words obtained by splitting reference sentence string hyp: list of words obtained by splitting hypothesis sentence string - steps: list of edit steps with elements 'corr', 'sub', 'ins' or 'del'. + steps: list of edit steps with elements "corr", "sub", "ins" or "del". Return: out_str: aligned reference and hypothesis string with edit steps. @@ -223,70 +223,76 @@ def aligned_print(ref, hyp, steps): if len(steps) == 0: # in case both ref and hyp are empty assert len(ref) == 0 and len(hyp) == 0 - out_str = 'REF: \nHYP: \nSTP: \nWER: {:.2f}%\n\n'.format(0.) + out_str = "REF: \nHYP: \nSTP: \nWER: {:.2f}%\n\n".format(0.0) return out_str - out_str = 'REF: ' + out_str = "REF: " for i in range(len(steps)): - delim = ' ' if i < len(steps) - 1 else '\n' - if steps[i] == 'sub': - ref_idx = i - steps[:i].count('ins') - hyp_idx = i - steps[:i].count('del') + delim = " " if i < len(steps) - 1 else "\n" + if steps[i] == "sub": + ref_idx = i - steps[: i].count("ins") + hyp_idx = i - steps[: i].count("del") if len(ref[ref_idx]) < len(hyp[hyp_idx]): - out_str += ref[ref_idx] + \ - ' ' * (len(hyp[hyp_idx]) - len(ref[ref_idx])) + delim + out_str += ( + ref[ref_idx] + " " * (len(hyp[hyp_idx]) - len(ref[ref_idx])) + delim + ) else: out_str += ref[ref_idx] + delim - elif steps[i] == 'ins': - idx = i - steps[:i].count('del') - out_str += ' ' * len(hyp[idx]) + delim + elif steps[i] == "ins": + idx = i - steps[: i].count("del") + out_str += " " * len(hyp[idx]) + delim else: - assert steps[i] == 'del' or steps[i] == 'corr' - idx = i - steps[:i].count('ins') + assert steps[i] == "del" or steps[i] == "corr" + idx = i - steps[: i].count("ins") out_str += ref[idx] + delim - out_str += 'HYP: ' + out_str += "HYP: " for i in range(len(steps)): - delim = ' ' if i < len(steps) - 1 else '\n' - if steps[i] == 'sub': - ref_idx = i - steps[:i].count('ins') - hyp_idx = i - steps[:i].count('del') + delim = " " if i < len(steps) - 1 else "\n" + if steps[i] == "sub": + ref_idx = i - steps[: i].count("ins") + hyp_idx = i - steps[: i].count("del") if len(ref[ref_idx]) > len(hyp[hyp_idx]): - out_str += hyp[hyp_idx] + \ - ' ' * (len(ref[ref_idx]) - len(hyp[hyp_idx])) + delim + out_str += ( + hyp[hyp_idx] + " " * (len(ref[ref_idx]) - len(hyp[hyp_idx])) + + delim + ) else: out_str += hyp[hyp_idx] + delim - elif steps[i] == 'del': - idx = i - steps[:i].count('ins') - out_str += ' ' * len(ref[idx]) + delim + elif steps[i] == "del": + idx = i - steps[: i].count("ins") + out_str += " " * len(ref[idx]) + delim else: - assert steps[i] == 'ins' or steps[i] == 'corr' - idx = i - steps[:i].count('del') + assert steps[i] == "ins" or steps[i] == "corr" + idx = i - 
steps[: i].count("del") out_str += hyp[idx] + delim - out_str += 'STP: ' + out_str += "STP: " for i in range(len(steps)): - delim = ' ' if i < len(steps) - 1 else '\n' - if steps[i] == 'sub': - ref_idx = i - steps[:i].count('ins') - hyp_idx = i - steps[:i].count('del') + delim = " " if i < len(steps) - 1 else "\n" + if steps[i] == "sub": + ref_idx = i - steps[: i].count("ins") + hyp_idx = i - steps[: i].count("del") if len(ref[ref_idx]) > len(hyp[hyp_idx]): - out_str += 'S' + ' ' * (len(ref[ref_idx]) - 1) + delim + out_str += "S" + " " * (len(ref[ref_idx]) - 1) + delim else: - out_str += 'S' + ' ' * (len(hyp[hyp_idx]) - 1) + delim - elif steps[i] == 'ins': - idx = i - steps[:i].count('del') - out_str += 'I' + ' ' * (len(hyp[idx]) - 1) + delim + out_str += "S" + " " * (len(hyp[hyp_idx]) - 1) + delim + elif steps[i] == "ins": + idx = i - steps[: i].count("del") + out_str += "I" + " " * (len(hyp[idx]) - 1) + delim else: - assert steps[i] == 'del' or steps[i] == 'corr' - idx = i - steps[:i].count('ins') - sym = 'D' if steps[i] == 'del' else ' ' - out_str += sym + ' ' * (len(ref[idx]) - 1) + delim + assert steps[i] == "del" or steps[i] == "corr" + idx = i - steps[: i].count("ins") + sym = "D" if steps[i] == "del" else " " + out_str += sym + " " * (len(ref[idx]) - 1) + delim counter = Counter(steps) - wer = float(counter['sub'] + counter['ins'] + counter['del']) / len(ref) \ - * 100 if len(ref) > 0 else 0. - out_str += 'WER: ' + '{:.2f}%'.format(wer) + '\n' - out_str += '\n' + wer = ( + float(counter["sub"] + counter["ins"] + counter["del"]) / len(ref) * 100 + if len(ref) > 0 + else 0.0 + ) + out_str += "WER: " + "{:.2f}%".format(wer) + "\n" + out_str += "\n" return out_str diff --git a/espresso/tools/wer.py b/espresso/tools/wer.py index 5314682d9..839a9b0c2 100644 --- a/espresso/tools/wer.py +++ b/espresso/tools/wer.py @@ -32,52 +32,54 @@ def reset(self): def parse_wer_output_filter(self, wer_output_filter): if wer_output_filter: - with open(PathManager.get_local_path(wer_output_filter), 'r', encoding='utf-8') as f: + with open(PathManager.get_local_path(wer_output_filter), "r", encoding="utf-8") as f: for line in f: line = line.strip() - if line.startswith('#!') or line == '': + if line.startswith("#!") or line == "": continue - elif line.startswith('s/'): - m = re.match(r's/(.+)/(.*)/g', line) + elif line.startswith("s/"): + m = re.match(r"s/(.+)/(.*)/g", line) assert m is not None self.word_filters.append([m.group(1), m.group(2)]) - elif line.startswith('s:'): - m = re.match(r's:(.+):(.*):g', line) + elif line.startswith("s:"): + m = re.match(r"s:(.+):(.*):g", line) assert m is not None self.word_filters.append([m.group(1), m.group(2)]) else: - logger.warning('Unsupported pattern: "{}". Ignoring it'.format(line)) + logger.warning("Unsupported pattern: '{}'. 
Ignoring it".format(line)) def add_prediction(self, utt_id, pred): if not isinstance(utt_id, str): - raise TypeError('utt_id must be a string(got {})'.format(type(utt_id))) + raise TypeError("utt_id must be a string(got {})".format(type(utt_id))) if not isinstance(pred, str): - raise TypeError('pred must be a string(got {})'.format(type(pred))) + raise TypeError("pred must be a string(got {})".format(type(pred))) - assert utt_id not in self.char_results, \ - 'Duplicated utterance id detected: {}'.format(utt_id) - self.char_results[utt_id] = pred + '\n' + assert ( + utt_id not in self.char_results + ), "Duplicated utterance id detected: {}".format(utt_id) + self.char_results[utt_id] = pred + "\n" pred_words = self.dictionary.wordpiece_decode(pred) - assert utt_id not in self.results, \ - 'Duplicated utterance id detected: {}'.format(utt_id) - self.results[utt_id] = pred_words + '\n' + assert ( + utt_id not in self.results + ), "Duplicated utterance id detected: {}".format(utt_id) + self.results[utt_id] = pred_words + "\n" def add_evaluation(self, utt_id, ref, pred): if not isinstance(utt_id, str): - raise TypeError('utt_id must be a string(got {})'.format(type(utt_id))) + raise TypeError("utt_id must be a string(got {})".format(type(utt_id))) if not isinstance(ref, str): - raise TypeError('ref must be a string (got {})'.format(type(ref))) + raise TypeError("ref must be a string (got {})".format(type(ref))) if not isinstance(pred, str): - raise TypeError('pred must be a string(got {})'.format(type(pred))) + raise TypeError("pred must be a string(got {})".format(type(pred))) # filter out any non_lang_syms from ref and pred - non_lang_syms = getattr(self.dictionary, 'non_lang_syms', None) + non_lang_syms = getattr(self.dictionary, "non_lang_syms", None) assert non_lang_syms is None or isinstance(non_lang_syms, list) if non_lang_syms is not None and len(non_lang_syms) > 0: ref_list, pred_list = ref.strip().split(), pred.strip().split() - ref = ' '.join([x for x in ref_list if x not in non_lang_syms]) - pred = ' '.join([x for x in pred_list if x not in non_lang_syms]) + ref = " ".join([x for x in ref_list if x not in non_lang_syms]) + pred = " ".join([x for x in pred_list if x not in non_lang_syms]) # char level counts _, _, counter = speech_utils.edit_distance( @@ -99,45 +101,48 @@ def add_evaluation(self, utt_id, ref, pred): ref_word_list, pred_word_list, ) self.word_counter += counter - assert utt_id not in self.aligned_results, \ - 'Duplicated utterance id detected: {}'.format(utt_id) + assert ( + utt_id not in self.aligned_results + ), "Duplicated utterance id detected: {}".format(utt_id) self.aligned_results[utt_id] = speech_utils.aligned_print( ref_word_list, pred_word_list, steps, ) def cer(self): - assert self.char_counter['words'] > 0 + assert self.char_counter["words"] > 0 cer = float( - self.char_counter['sub'] + self.char_counter['ins'] + self.char_counter['del'] - ) / self.char_counter['words'] * 100 - sub = float(self.char_counter['sub']) / self.char_counter['words'] * 100 - ins = float(self.char_counter['ins']) / self.char_counter['words'] * 100 - dlt = float(self.char_counter['del']) / self.char_counter['words'] * 100 + self.char_counter["sub"] + self.char_counter["ins"] + self.char_counter["del"] + ) / self.char_counter["words"] * 100 + sub = float(self.char_counter["sub"]) / self.char_counter["words"] * 100 + ins = float(self.char_counter["ins"]) / self.char_counter["words"] * 100 + dlt = float(self.char_counter["del"]) / self.char_counter["words"] * 100 return cer, sub, ins, 
dlt def wer(self): - assert self.word_counter['words'] > 0 + assert self.word_counter["words"] > 0 wer = float( - self.word_counter['sub'] + self.word_counter['ins'] + self.word_counter['del'] - ) / self.word_counter['words'] * 100 - sub = float(self.word_counter['sub']) / self.word_counter['words'] * 100 - ins = float(self.word_counter['ins']) / self.word_counter['words'] * 100 - dlt = float(self.word_counter['del']) / self.word_counter['words'] * 100 + self.word_counter["sub"] + self.word_counter["ins"] + self.word_counter["del"] + ) / self.word_counter["words"] * 100 + sub = float(self.word_counter["sub"]) / self.word_counter["words"] * 100 + ins = float(self.word_counter["ins"]) / self.word_counter["words"] * 100 + dlt = float(self.word_counter["del"]) / self.word_counter["words"] * 100 return wer, sub, ins, dlt def tot_word_error(self): - return self.word_counter['sub'] + self.word_counter['ins'] + \ - self.word_counter['del'] + return ( + self.word_counter["sub"] + self.word_counter["ins"] + self.word_counter["del"] + ) def tot_word_count(self): - return self.word_counter['words'] + return self.word_counter["words"] def tot_char_error(self): - return self.char_counter['sub'] + self.char_counter['ins'] + \ - self.char_counter['del'] + return ( + self.char_counter["sub"] + self.char_counter["ins"] + self.char_counter["del"] + ) def tot_char_count(self): - return self.char_counter['words'] + return self.char_counter["words"] def add_ordered_utt_list(self, *args): if len(args) == 1 and isinstance(args[0], list): # aleady a list of utterance ids @@ -145,7 +150,7 @@ def add_ordered_utt_list(self, *args): return self.ordered_utt_list = [] for text_file in args: - with open(PathManager.get_local_path(text_file), 'r', encoding='utf-8') as f: + with open(PathManager.get_local_path(text_file), "r", encoding="utf-8") as f: one_utt_list = [line.strip().split()[0] for line in f] self.ordered_utt_list.extend(one_utt_list) if len(self.char_results): @@ -156,34 +161,34 @@ def add_ordered_utt_list(self, *args): assert set(self.ordered_utt_list) == set(self.aligned_results.keys()) def print_char_results(self): - res = '' + res = "" if self.ordered_utt_list is not None: assert set(self.ordered_utt_list) == set(self.char_results.keys()) for utt_id in self.ordered_utt_list: - res += utt_id + ' ' + self.char_results[utt_id] + res += utt_id + " " + self.char_results[utt_id] else: for utt_id in self.char_results: - res += utt_id + ' ' + self.char_results[utt_id] + res += utt_id + " " + self.char_results[utt_id] return res def print_results(self): - res = '' + res = "" if self.ordered_utt_list is not None: assert set(self.ordered_utt_list) == set(self.results.keys()) for utt_id in self.ordered_utt_list: - res += utt_id + ' ' + self.results[utt_id] + res += utt_id + " " + self.results[utt_id] else: for utt_id in self.results: - res += utt_id + ' ' + self.results[utt_id] + res += utt_id + " " + self.results[utt_id] return res def print_aligned_results(self): - res = '' + res = "" if self.ordered_utt_list is not None: assert set(self.ordered_utt_list) == set(self.aligned_results.keys()) for utt_id in self.ordered_utt_list: - res += utt_id + '\n' + self.aligned_results[utt_id] + res += utt_id + "\n" + self.aligned_results[utt_id] else: for utt_id in self.aligned_results: - res += utt_id + '\n' + self.aligned_results[utt_id] + res += utt_id + "\n" + self.aligned_results[utt_id] return res diff --git a/examples/asr_librispeech/run.sh b/examples/asr_librispeech/run.sh index 13fc0e98e..786acca27 100755 --- 
a/examples/asr_librispeech/run.sh +++ b/examples/asr_librispeech/run.sh @@ -153,7 +153,7 @@ if [ ${stage} -le 4 ]; then for dataset in $test_set; do test_paths="$test_paths $lmdatadir/$dataset.tokens"; done test_paths=$(echo $test_paths | awk '{$1=$1;print}' | tr ' ' ',') ${decode_cmd} $lmdatadir/log/preprocess.log \ - python3 ../../fairseq_cli/preprocess.py --user-dir espresso --task language_modeling_for_asr \ + python3 ../../fairseq_cli/preprocess.py --task language_modeling_for_asr \ --workers 50 --srcdict $lmdict --only-source \ --trainpref $lmdatadir/train.tokens \ --validpref $lmdatadir/$valid_set.tokens \ @@ -173,7 +173,7 @@ if [ ${stage} -le 5 ]; then mkdir -p $lmdir/log log_file=$lmdir/log/train.log [ -f $lmdir/checkpoint_last.pt ] && log_file="-a $log_file" - CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../fairseq_cli/train.py $lmdatadir --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../fairseq_cli/train.py $lmdatadir --seed 1 \ --task language_modeling_for_asr --dict $lmdict \ --log-interval $((16000/ngpus)) --log-format simple \ --num-workers 0 --max-tokens 32000 --batch-size 1024 --curriculum 1 \ @@ -194,7 +194,7 @@ if [ ${stage} -le 6 ]; then test_set_array=($test_set) for i in $(seq 0 $num); do log_file=$lmdir/log/evaluation_${test_set_array[$i]}.log - python3 ../../fairseq_cli/eval_lm.py $lmdatadir --user-dir espresso --cpu \ + python3 ../../fairseq_cli/eval_lm.py $lmdatadir --cpu \ --task language_modeling_for_asr --dict $lmdict --gen-subset ${gen_set_array[$i]} \ --max-tokens 40960 --batch-size 1536 --sample-break-mode eos \ --path $lmdir/$lm_checkpoint 2>&1 | tee $log_file @@ -245,7 +245,7 @@ if [ ${stage} -le 8 ]; then opts="$opts --max-epoch 30 --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --start-reduce-lr-epoch 10" fi fi - CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data --task speech_recognition_espresso --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data --task speech_recognition_espresso --seed 1 \ --log-interval $((8000/ngpus/update_freq)) --log-format simple --print-training-sample-interval $((4000/ngpus/update_freq)) \ --num-workers 0 --data-buffer-size 0 --max-tokens 26000 --batch-size 24 --curriculum 1 --empty-cache-freq 50 \ --valid-subset $valid_subset --batch-size-valid 48 --ddp-backend no_c10d --update-freq $update_freq \ @@ -276,7 +276,7 @@ if [ ${stage} -le 9 ]; then for dataset in $test_set; do decode_dir=$dir/decode_$dataset${decode_affix:+_${decode_affix}} CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py data \ - --task speech_recognition_espresso --user-dir espresso --max-tokens 15000 --batch-size 24 \ + --task speech_recognition_espresso --max-tokens 15000 --batch-size 24 \ --num-shards 1 --shard-id 0 --dict $dict --bpe sentencepiece --sentencepiece-model ${sentencepiece_model}.model \ --gen-subset $dataset --max-source-positions 9999 --max-target-positions 999 \ --path $path --beam 60 --max-len-a 0.08 --max-len-b 0 --lenpen 1.0 \ diff --git a/examples/asr_swbd/run.sh b/examples/asr_swbd/run.sh index 95ab5403a..f3be25a98 100755 --- a/examples/asr_swbd/run.sh +++ b/examples/asr_swbd/run.sh @@ -191,7 +191,7 @@ if [ $stage -le 3 ]; then test_paths= && for dataset in $test_set; do test_paths="$test_paths $lmdatadir/$dataset.tokens"; done test_paths=$(echo $test_paths | awk '{$1=$1;print}' | tr ' ' ',') ${decode_cmd} $lmdatadir/log/preprocess.log \ - python3 ../../fairseq_cli/preprocess.py --user-dir espresso --task 
language_modeling_for_asr \ + python3 ../../fairseq_cli/preprocess.py --task language_modeling_for_asr \ --workers 50 --srcdict $lmdict --only-source \ --trainpref $lmdatadir/train.tokens \ --validpref $lmdatadir/$valid_set.tokens \ @@ -211,7 +211,7 @@ if [ $stage -le 4 ]; then mkdir -p $lmdir/log log_file=$lmdir/log/train.log [ -f $lmdir/checkpoint_last.pt ] && log_file="-a $log_file" - CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../fairseq_cli/train.py $lmdatadir --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../fairseq_cli/train.py $lmdatadir --seed 1 \ --task language_modeling_for_asr --dict $lmdict \ --log-interval $((1000/ngpus)) --log-format simple \ --num-workers 0 --max-tokens 25600 --batch-size 1024 \ @@ -232,7 +232,7 @@ if [ $stage -le 5 ]; then test_set_array=($test_set) for i in $(seq 0 $num); do log_file=$lmdir/log/evaluation_${test_set_array[$i]}.log - python3 ../../fairseq_cli/eval_lm.py $lmdatadir --user-dir espresso --cpu \ + python3 ../../fairseq_cli/eval_lm.py $lmdatadir --cpu \ --task language_modeling_for_asr --dict $lmdict --gen-subset ${gen_set_array[$i]} \ --max-tokens 40960 --batch-size 1536 --sample-break-mode eos \ --path $lmdir/$lm_checkpoint 2>&1 | tee $log_file @@ -284,7 +284,7 @@ if [ $stage -le 7 ]; then opts="$opts --max-epoch 35 --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 --start-reduce-lr-epoch 14" fi fi - CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data --task speech_recognition_espresso --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data --task speech_recognition_espresso --seed 1 \ --log-interval $((3000/ngpus/update_freq)) --log-format simple --print-training-sample-interval $((4000/ngpus/update_freq)) \ --num-workers 0 --data-buffer-size 0 --max-tokens 26000 --batch-size 48 --curriculum 2 --empty-cache-freq 50 \ --valid-subset $valid_subset --batch-size-valid 64 --ddp-backend no_c10d --update-freq $update_freq \ @@ -313,7 +313,7 @@ if [ $stage -le 8 ]; then for dataset in $test_set; do decode_dir=$dir/decode_${dataset}${decode_affix:+_${decode_affix}} CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py data \ - --task speech_recognition_espresso --user-dir espresso --max-tokens 24000 --batch-size 48 \ + --task speech_recognition_espresso --max-tokens 24000 --batch-size 48 \ --num-shards 1 --shard-id 0 --dict $dict --bpe sentencepiece --sentencepiece-model ${sentencepiece_model}.model \ --non-lang-syms $nlsyms --gen-subset $dataset --max-source-positions 9999 --max-target-positions 999 \ --path $path --beam 35 --max-len-a 0.1 --max-len-b 0 --lenpen 1.0 \ diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index ae3ea5fe7..c0bba6d61 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -158,7 +158,7 @@ if [ ${stage} -le 3 ]; then echo "$0: binarizing char text..." mkdir -p $lmdatadir/log ${decode_cmd} $lmdatadir/log/preprocess.log \ - python3 ../../fairseq_cli/preprocess.py --user-dir espresso --task language_modeling_for_asr \ + python3 ../../fairseq_cli/preprocess.py --task language_modeling_for_asr \ --workers 30 --srcdict $lmdict --only-source \ --trainpref $lmdatadir/train.tokens \ --validpref $lmdatadir/$valid_set.tokens \ @@ -168,7 +168,7 @@ if [ ${stage} -le 3 ]; then echo "$0: binarizing word text..." 
mkdir -p $wordlmdatadir/log ${decode_cmd} $wordlmdatadir/log/preprocess.log \ - python3 ../../fairseq_cli/preprocess.py --user-dir espresso --task language_modeling_for_asr \ + python3 ../../fairseq_cli/preprocess.py --task language_modeling_for_asr \ --workers 30 --srcdict $wordlmdict --only-source \ --trainpref $wordlmdatadir/train \ --validpref $wordlmdatadir/$valid_set \ @@ -189,7 +189,7 @@ if [ ${stage} -le 4 ] && ! $use_wordlm; then mkdir -p $lmdir/log log_file=$lmdir/log/train.log [ -f $lmdir/checkpoint_last.pt ] && log_file="-a $log_file" - CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../fairseq_cli/train.py $lmdatadir --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../fairseq_cli/train.py $lmdatadir --seed 1 \ --task language_modeling_for_asr --dict $lmdict \ --log-interval $((4000/ngpus)) --log-format simple \ --num-workers 0 --max-tokens 25600 --batch-size 128 \ @@ -206,7 +206,7 @@ if [ ${stage} -le 5 ] && ! $use_wordlm; then echo "Stage 5: char LM Evaluation" for gen_subset in valid test; do log_file=$lmdir/log/evaluation_$gen_subset.log - python3 ../../fairseq_cli/eval_lm.py $lmdatadir --user-dir espresso --cpu \ + python3 ../../fairseq_cli/eval_lm.py $lmdatadir --cpu \ --task language_modeling_for_asr --dict $lmdict --gen-subset $gen_subset \ --max-tokens 192000 --batch-size 256 --sample-break-mode eos \ --path $lmdir/$lm_checkpoint 2>&1 | tee $log_file @@ -219,7 +219,7 @@ if [ ${stage} -le 6 ] && $use_wordlm; then mkdir -p $wordlmdir/log log_file=$wordlmdir/log/train.log [ -f $wordlmdir/checkpoint_last.pt ] && log_file="-a $log_file" - CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../fairseq_cli/train.py $wordlmdatadir --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu python3 ../../fairseq_cli/train.py $wordlmdatadir --seed 1 \ --task language_modeling_for_asr --dict $wordlmdict \ --log-interval $((4000/ngpus)) --log-format simple \ --num-workers 0 --max-tokens 6400 --batch-size 256 \ @@ -237,7 +237,7 @@ if [ ${stage} -le 7 ] && $use_wordlm; then echo "Stage 7: word LM Evaluation" for gen_subset in valid test; do log_file=$wordlmdir/log/evaluation_$gen_subset.log - python3 ../../fairseq_cli/eval_lm.py $wordlmdatadir --user-dir espresso --cpu \ + python3 ../../fairseq_cli/eval_lm.py $wordlmdatadir --cpu \ --task language_modeling_for_asr --dict $wordlmdict --gen-subset $gen_subset \ --max-tokens 12800 --batch-size 512 --sample-break-mode eos \ --path $wordlmdir/$lm_checkpoint 2>&1 | tee $log_file @@ -283,7 +283,7 @@ if [ ${stage} -le 9 ]; then opts="$opts --lr-shrink 0.5 --start-reduce-lr-epoch 11" opts="$opts --scheduled-sampling-probs 0.5 --start-scheduled-sampling-epoch 6" fi - CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data --task speech_recognition_espresso --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data --task speech_recognition_espresso --seed 1 \ --log-interval $((800/ngpus/update_freq)) --log-format simple --print-training-sample-interval $((2000/ngpus/update_freq)) \ --num-workers 0 --data-buffer-size 0 --max-tokens 24000 --batch-size 32 --curriculum 2 --empty-cache-freq 50 \ --valid-subset $valid_subset --batch-size-valid 64 --ddp-backend no_c10d --update-freq $update_freq \ @@ -316,7 +316,7 @@ if [ ${stage} -le 10 ]; then for dataset in $valid_set $test_set; do decode_dir=$dir/decode_$dataset${decode_affix:+_${decode_affix}} CUDA_VISIBLE_DEVICES=$(echo $free_gpu | sed 's/,/ /g' | awk '{print $1}') speech_recognize.py data \ - --task speech_recognition_espresso --user-dir espresso 
--max-tokens 20000 --batch-size 32 \ + --task speech_recognition_espresso --max-tokens 20000 --batch-size 32 \ --num-shards 1 --shard-id 0 --dict $dict --bpe characters_asr --non-lang-syms $nlsyms \ --gen-subset $dataset --max-source-positions 9999 --max-target-positions 999 \ --path $path --beam 50 --max-len-a 0.2 --max-len-b 0 --lenpen 1.0 \ diff --git a/examples/asr_wsj/run_chain_e2e.sh b/examples/asr_wsj/run_chain_e2e.sh index 3d524db80..c63240af0 100755 --- a/examples/asr_wsj/run_chain_e2e.sh +++ b/examples/asr_wsj/run_chain_e2e.sh @@ -185,7 +185,7 @@ if [ ${stage} -le 6 ]; then log_file=$dir/log/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" update_freq=1 - CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data/chain_e2e --task speech_recognition_hybrid --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data/chain_e2e --task speech_recognition_hybrid --seed 1 \ --log-interval $((200/ngpus/update_freq)) --log-format simple \ --num-workers 0 --data-buffer-size 0 --max-tokens 120000 --batch-size 128 --curriculum 1 --empty-cache-freq 50 \ --valid-subset $valid_subset --batch-size-valid 128 --ddp-backend no_c10d --update-freq $update_freq \ @@ -212,7 +212,7 @@ if [ ${stage} -le 7 ]; then for lmtype in tgpr bd_tgpr; do graph_dir=$tree_dir/graph_${lmtype} $decode_cmd $queue_opt JOB=1:$nj $dir/decode_${lmtype}_${data_affix}/log/decode.JOB.log \ - dump_posteriors.py data/chain_e2e --cpu --task speech_recognition_hybrid --user-dir espresso \ + dump_posteriors.py data/chain_e2e --cpu --task speech_recognition_hybrid \ --max-tokens 120000 --batch-size 128 --num-shards 1 --shard-id 0 --num-targets $num_targets \ --gen-subset $dataset.JOB \ --max-source-positions 9999 --path $path \| \ diff --git a/examples/asr_wsj/run_chain_e2e_bichar.sh b/examples/asr_wsj/run_chain_e2e_bichar.sh index 8829f28e6..23ead11f7 100755 --- a/examples/asr_wsj/run_chain_e2e_bichar.sh +++ b/examples/asr_wsj/run_chain_e2e_bichar.sh @@ -185,7 +185,7 @@ if [ ${stage} -le 6 ]; then log_file=$dir/log/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" update_freq=1 - CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data/chain_e2e_bichar --task speech_recognition_hybrid --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data/chain_e2e_bichar --task speech_recognition_hybrid --seed 1 \ --log-interval $((200/ngpus/update_freq)) --log-format simple \ --num-workers 0 --data-buffer-size 0 --max-tokens 120000 --batch-size 128 --curriculum 1 --empty-cache-freq 50 \ --valid-subset $valid_subset --batch-size-valid 128 --ddp-backend no_c10d --update-freq $update_freq \ @@ -212,7 +212,7 @@ if [ ${stage} -le 7 ]; then for lmtype in tgpr bd_tgpr; do graph_dir=$tree_dir/graph_${lmtype} $decode_cmd $queue_opt JOB=1:$nj $dir/decode_${lmtype}_${data_affix}/log/decode.JOB.log \ - dump_posteriors.py data/chain_e2e_bichar --cpu --task speech_recognition_hybrid --user-dir espresso \ + dump_posteriors.py data/chain_e2e_bichar --cpu --task speech_recognition_hybrid \ --max-tokens 120000 --batch-size 128 --num-shards 1 --shard-id 0 --num-targets $num_targets \ --gen-subset $dataset.JOB \ --max-source-positions 9999 --path $path \| \ diff --git a/examples/asr_wsj/run_xent.sh b/examples/asr_wsj/run_xent.sh index d66b937a3..e58fdc466 100755 --- a/examples/asr_wsj/run_xent.sh +++ b/examples/asr_wsj/run_xent.sh @@ -165,7 +165,7 @@ if [ ${stage} -le 5 ]; then log_file=$dir/log/train.log [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" update_freq=1 - 
CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data/xent --task speech_recognition_hybrid --seed 1 --user-dir espresso \ + CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data/xent --task speech_recognition_hybrid --seed 1 \ --log-interval $((100/ngpus/update_freq)) --log-format simple \ --num-workers 0 --data-buffer-size 0 --max-tokens 160000 --batch-size 256 --empty-cache-freq 50 \ --valid-subset $valid_subset --batch-size-valid 256 --ddp-backend no_c10d --update-freq $update_freq \ @@ -192,10 +192,10 @@ if [ ${stage} -le 6 ]; then for lmtype in tgpr bd_tgpr; do graph_dir=exp/$gmm/graph_${lmtype} $decode_cmd $queue_opt JOB=1:$nj $dir/decode_${lmtype}_${data_affix}/log/decode.JOB.log \ - dump_posteriors.py data/xent --cpu --task speech_recognition_hybrid --user-dir espresso \ + dump_posteriors.py data/xent --cpu --task speech_recognition_hybrid \ --max-tokens 256000 --batch-size 256 --num-shards 1 --shard-id 0 --num-targets $num_targets \ --gen-subset $dataset.JOB --chunk-width 150 --chunk-left-context 10 --chunk-right-context 10 --label-delay -3 \ - --max-source-positions 9999 --path $path --apply-log-softmax \| \ + --max-source-positions 9999 --path $path --apply-log-softmax True \| \ latgen-faster-mapped --max-active=7000 --min-active=20 --beam=15 --lattice-beam=8 --acoustic-scale=0.1 \ --allow-partial=true --word-symbol-table="$graph_dir/words.txt" \ exp/$gmm/final.mdl $graph_dir/HCLG.fst ark:- "ark:|gzip -c >$dir/decode_${lmtype}_${data_affix}/lat.JOB.gz" || exit 1 diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index caf4a7a2b..0137c1ae8 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -827,6 +827,47 @@ class GenerationConfig(FairseqDataclass): default=False, metadata={"help": "if set, dont use seed for initializing random generators"}, ) + # for espresso.speech_recognize.py + eos_factor: Optional[float] = field( + default=None, + metadata={ + "help": "only consider emitting EOS if its score is no less than " + "the specified factor of the best candidate score" + }, + ) + subwordlm_weight: Optional[float] = field( + default=0.8, + metadata={ + "help": "subword LM weight relative to word LM. Only relevant to " + "MultiLevelLanguageModel as an external LM" + }, + ) + oov_penalty: Optional[float] = field( + default=1e-4, + metadata={"help": "oov penalty with the pretrained external LM"}, + ) + disable_open_vocab: Optional[bool] = field( + default=False, + metadata={ + "help": "whether open vocabulary mode is enabled with the " + "pretrained external LM" + }, + ) + # for espresso.dump_posteriors.py + apply_log_softmax: Optional[bool] = field( + default=False, + metadata={ + "help": "apply log-softmax to the neural network outputs for Xent " + "hybrid systems; otherwise use the raw outputs" + }, + ) + state_prior_file: Optional[str] = field( + default=None, + metadata={ + "help": "state prior file. 
If provided, use this file instead of " + "that from the checkpoint" + }, + ) @dataclass diff --git a/tests/espresso/test_speech_utils.py b/tests/espresso/test_speech_utils.py index 96173be1e..d00b95c88 100644 --- a/tests/espresso/test_speech_utils.py +++ b/tests/espresso/test_speech_utils.py @@ -13,7 +13,6 @@ import torch from espresso.data import AsrDictionary - import espresso.tools.utils as utils @@ -91,8 +90,11 @@ def test_speech_tokenizer(self): tensor, extra_symbols_to_ignore={self.dictionary.pad()} ) expected_tokens = " ".join( - [token if self.dictionary.index(token) != self.dictionary.unk() else - self.dictionary.unk_word for token in tokens.split(" ")] + [ + token if self.dictionary.index(token) != self.dictionary.unk() + else self.dictionary.unk_word + for token in tokens.split(" ") + ] ) self.assertEqual(reconstructed_tokens, expected_tokens) From 513b171aecf25e3253c065c92e870000d2c5a1c3 Mon Sep 17 00:00:00 2001 From: freewym Date: Thu, 5 Nov 2020 00:52:02 -0500 Subject: [PATCH 103/119] code adaptation/changes according to the commits on Nov 4-9, 2020 --- espresso/data/asr_chain_dataset.py | 2 +- espresso/data/asr_dictionary.py | 8 +- espresso/data/encoders/characters_asr.py | 2 +- espresso/speech_train.py | 97 ++++++++++++++------- espresso/tasks/speech_recognition.py | 16 ++-- espresso/tasks/speech_recognition_hybrid.py | 8 +- 6 files changed, 85 insertions(+), 48 deletions(-) diff --git a/espresso/data/asr_chain_dataset.py b/espresso/data/asr_chain_dataset.py index 5b88580fc..b6f89fcf8 100644 --- a/espresso/data/asr_chain_dataset.py +++ b/espresso/data/asr_chain_dataset.py @@ -307,7 +307,7 @@ def collater(self, samples, pad_to_length=None): samples (List[dict]): samples to collate pad_to_length (dict, optional): a dictionary of {"source": source_pad_to_length} - to indicate the max length to pad to in source and target respectively. + to indicate the max length to pad to in source. 
Returns: dict: a mini-batch with the following keys: diff --git a/espresso/data/asr_dictionary.py b/espresso/data/asr_dictionary.py index 4cdd44dac..53422347d 100644 --- a/espresso/data/asr_dictionary.py +++ b/espresso/data/asr_dictionary.py @@ -8,8 +8,8 @@ import torch from fairseq.data import Dictionary, encoders +from fairseq.dataclass import FairseqDataclass from fairseq.file_io import PathManager -from omegaconf import DictConfig # will automatically load modules defined from there from espresso.data import encoders as encoders_espresso @@ -99,12 +99,12 @@ def dummy_sentence(self, length): t[-1] = self.eos() return t - def build_tokenizer(self, cfg: Union[DictConfig, Namespace]): + def build_tokenizer(self, cfg: Union[FairseqDataclass, Namespace]): self.tokenizer = encoders.build_tokenizer(cfg) - def build_bpe(self, cfg: Union[DictConfig, Namespace]): + def build_bpe(self, cfg: Union[FairseqDataclass, Namespace]): if ( - (isinstance(cfg, DictConfig) and cfg._name == "characters_asr") + (isinstance(cfg, FairseqDataclass) and cfg._name == "characters_asr") or (isinstance(cfg, Namespace) and getattr(cfg, "bpe", None) == "characters_asr") ): self.bpe = encoders.build_bpe( diff --git a/espresso/data/encoders/characters_asr.py b/espresso/data/encoders/characters_asr.py index 0bd9a48d0..61360a7a9 100644 --- a/espresso/data/encoders/characters_asr.py +++ b/espresso/data/encoders/characters_asr.py @@ -20,7 +20,7 @@ class CharactersAsrConfig(FairseqDataclass): @register_bpe("characters_asr", dataclass=CharactersAsrConfig) class CharactersAsr(object): def __init__( - self, cfg, space_symbol="", ends_with_space=True, + self, cfg: CharactersAsrConfig, space_symbol="", ends_with_space=True, non_lang_syms: Optional[List[str]] = None, ): self.space_symbol = space_symbol diff --git a/espresso/speech_train.py b/espresso/speech_train.py index 9391ba72e..ffe14b661 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -49,8 +49,9 @@ def main(cfg: DictConfig) -> None: utils.import_user_module(cfg.common) - assert cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None, \ - "Must specify batch size either with --max-tokens or --batch-size" + assert ( + cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None + ), "Must specify batch size either with --max-tokens or --batch-size" metrics.reset() np.random.seed(cfg.common.seed) @@ -71,19 +72,21 @@ def main(cfg: DictConfig) -> None: for valid_sub_split in cfg.dataset.valid_subset.split(","): task.load_dataset(valid_sub_split, combine=False, epoch=1) + assert cfg.criterion, "Please specify criterion to train a model" + # Build model and criterion model = task.build_model(cfg.model) criterion = task.build_criterion(cfg.criterion) logger.info(model) - logger.info("task: {} ({})".format(cfg.task._name, task.__class__.__name__)) - logger.info("model: {} ({})".format(cfg.model._name, model.__class__.__name__)) + logger.info("task: {}".format(task.__class__.__name__)) + logger.info("model: {}".format(model.__class__.__name__)) + logger.info("criterion: {})".format(criterion.__class__.__name__)) logger.info( - "criterion: {} ({})".format(cfg.criterion._name, criterion.__class__.__name__) + "num. model params: {} (num. trained: {})".format( + sum(p.numel() for p in model.parameters()), + sum(p.numel() for p in model.parameters() if p.requires_grad), + ) ) - logger.info("num. model params: {} (num. 
trained: {})".format( - sum(p.numel() for p in model.parameters()), - sum(p.numel() for p in model.parameters() if p.requires_grad), - )) # (optionally) Configure quantization if cfg.common.quantization_config_path is not None: @@ -101,11 +104,17 @@ def main(cfg: DictConfig) -> None: else: trainer = MegatronTrainer(cfg, task, model, criterion) - logger.info("training on {} devices (GPUs/TPUs)".format(cfg.distributed_training.distributed_world_size)) - logger.info("max tokens per GPU = {} and batch size per GPU = {}".format( - cfg.dataset.max_tokens, - cfg.dataset.batch_size, - )) + logger.info( + "training on {} devices (GPUs/TPUs)".format( + cfg.distributed_training.distributed_world_size + ) + ) + logger.info( + "max tokens per GPU = {} and batch size per GPU = {}".format( + cfg.dataset.max_tokens, + cfg.dataset.batch_size, + ) + ) # Load the latest checkpoint if one is available and restore the # corresponding train iterator @@ -120,10 +129,7 @@ def main(cfg: DictConfig) -> None: lr = trainer.get_lr() train_meter = meters.StopwatchMeter() train_meter.start() - while ( - lr > cfg.optimization.min_lr - and epoch_itr.next_epoch_idx <= max_epoch - ): + while lr > cfg.optimization.min_lr and epoch_itr.next_epoch_idx <= max_epoch: # train for one epoch valid_losses, should_stop = train(cfg, trainer, task, epoch_itr) if should_stop: @@ -161,14 +167,20 @@ def is_better(a, b): else: should_stop_early.num_runs += 1 if should_stop_early.num_runs >= cfg.checkpoint.patience: - logger.info("early stop since valid performance hasn't improved for last {} runs".format(cfg.checkpoint.patience)) + logger.info( + "early stop since valid performance hasn't improved for last {} runs".format( + cfg.checkpoint.patience + ) + ) return True else: return False @metrics.aggregate("train") -def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr) -> Tuple[List[Optional[float]], bool]: +def train( + cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr +) -> Tuple[List[Optional[float]], bool]: """Train the model for one epoch and return validation losses.""" # Initialize data iterator itr = epoch_itr.next_epoch_itr( @@ -189,7 +201,9 @@ def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr) log_interval=cfg.common.log_interval, epoch=epoch_itr.epoch, tensorboard_logdir=( - cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None + cfg.common.tensorboard_logdir + if distributed_utils.is_master(cfg.distributed_training) + else None ), default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), wandb_project=( @@ -244,7 +258,14 @@ def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr) return valid_losses, should_stop -def validate_and_save(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr, valid_subsets: List[str], end_of_epoch: bool) -> Tuple[List[Optional[float]], bool]: +def validate_and_save( + cfg: DictConfig, + trainer: Trainer, + task: tasks.FairseqTask, + epoch_itr, + valid_subsets: List[str], + end_of_epoch: bool, +) -> Tuple[List[Optional[float]], bool]: num_updates = trainer.get_num_updates() max_update = cfg.optimization.max_update or math.inf do_save = ( @@ -279,14 +300,17 @@ def validate_and_save(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask or num_updates >= max_update or ( cfg.optimization.stop_time_hours > 0 - and trainer.cumulative_training_time() / (60 * 60) > cfg.optimization.stop_time_hours + and 
trainer.cumulative_training_time() / (60 * 60) + > cfg.optimization.stop_time_hours ) ) # Save checkpoint if do_save or should_stop: logger.info("begin save checkpoint") - checkpoint_utils.save_checkpoint(cfg.checkpoint, trainer, epoch_itr, valid_losses[0]) + checkpoint_utils.save_checkpoint( + cfg.checkpoint, trainer, epoch_itr, valid_losses[0] + ) return valid_losses, should_stop @@ -296,7 +320,13 @@ def get_training_stats(stats: Dict[str, Any]) -> Dict[str, Any]: return stats -def validate(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr, subsets: List[str]) -> List[Optional[float]]: +def validate( + cfg: DictConfig, + trainer: Trainer, + task: tasks.FairseqTask, + epoch_itr, + subsets: List[str], +) -> List[Optional[float]]: """Evaluate the model on the validation set(s) and return the losses.""" if cfg.dataset.fixed_validation_seed is not None: @@ -306,7 +336,7 @@ def validate(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_i trainer.begin_valid_epoch(epoch_itr.epoch) valid_losses = [] for subset in subsets: - logger.info("begin validation on '{}' subset".format(subset)) + logger.info('begin validation on "{}" subset'.format(subset)) # Initialize data iterator itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False) @@ -319,7 +349,9 @@ def validate(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_i epoch=epoch_itr.epoch, prefix=f"valid on '{subset}' subset", tensorboard_logdir=( - cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None + cfg.common.tensorboard_logdir + if distributed_utils.is_master(cfg.distributed_training) + else None ), default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), wandb_project=( @@ -341,13 +373,16 @@ def validate(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_i return valid_losses -def get_valid_stats(cfg: DictConfig, trainer: Trainer, stats: Dict[str, Any]) -> Dict[str, Any]: +def get_valid_stats( + cfg: DictConfig, trainer: Trainer, stats: Dict[str, Any] +) -> Dict[str, Any]: stats["num_updates"] = trainer.get_num_updates() if hasattr(checkpoint_utils.save_checkpoint, "best"): key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric) best_function = max if cfg.checkpoint.maximize_best_checkpoint_metric else min stats[key] = best_function( - checkpoint_utils.save_checkpoint.best, stats[cfg.checkpoint.best_checkpoint_metric] + checkpoint_utils.save_checkpoint.best, + stats[cfg.checkpoint.best_checkpoint_metric], ) return stats @@ -359,7 +394,9 @@ def print_options_meaning_changes(args): logger.info("--max-tokens is the maximum number of input frames in a batch") -def cli_main(modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None) -> None: +def cli_main( + modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None +) -> None: parser = options.get_training_parser() args = options.parse_args_and_arch(parser, modify_parser=modify_parser) print_options_meaning_changes(args) diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index 58ed02b1c..94cdb9836 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -18,7 +18,7 @@ from fairseq.dataclass import FairseqDataclass from fairseq.logging import metrics from fairseq.tasks import FairseqTask, register_task -from omegaconf import II, DictConfig +from omegaconf import II from espresso.data import ( AsrDataset, @@ -233,7 +233,7 @@ def 
build_dictionary( """ raise NotImplementedError - def __init__(self, cfg: DictConfig, tgt_dict, word_dict=None): + def __init__(self, cfg: SpeechRecognitionEspressoConfig, tgt_dict, word_dict=None): super().__init__(cfg) self.tgt_dict = tgt_dict self.word_dict = word_dict @@ -246,11 +246,11 @@ def __init__(self, cfg: DictConfig, tgt_dict, word_dict=None): torch.rand(1) @classmethod - def setup_task(cls, cfg: DictConfig, **kwargs): + def setup_task(cls, cfg: SpeechRecognitionEspressoConfig, **kwargs): """Setup the task (e.g., load dictionaries). Args: - cfg (omegaconf.DictConfig): parsed command-line arguments + cfg (SpeechRecognitionEspressoConfig): configuration of this task """ # load dictionaries dict_path = os.path.join(cfg.data, "dict.txt") if cfg.dict is None else cfg.dict @@ -315,8 +315,8 @@ def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None) constraints=constraints, ) - def build_model(self, cfg: DictConfig): - model = super().build_model(cfg) + def build_model(self, model_cfg: FairseqDataclass): + model = super().build_model(model_cfg) # build the greedy decoder for validation with WER from espresso.tools.simple_greedy_decoder import SimpleGreedyDecoder @@ -353,13 +353,13 @@ def target_dictionary(self): """Return the target :class:`~fairseq.data.AsrDictionary`.""" return self.tgt_dict - def build_tokenizer(self, cfg: DictConfig): + def build_tokenizer(self, cfg: FairseqDataclass): """Build the pre-tokenizer for this task.""" self.tgt_dict.build_tokenizer(cfg) # the instance is built within self.tgt_dict return self.tgt_dict.tokenizer - def build_bpe(self, cfg: DictConfig): + def build_bpe(self, cfg: FairseqDataclass): """Build the tokenizer for this task.""" self.tgt_dict.build_bpe(cfg) # the instance is built within self.tgt_dict diff --git a/espresso/tasks/speech_recognition_hybrid.py b/espresso/tasks/speech_recognition_hybrid.py index a10b07ea5..e2b1b1449 100644 --- a/espresso/tasks/speech_recognition_hybrid.py +++ b/espresso/tasks/speech_recognition_hybrid.py @@ -18,7 +18,7 @@ from fairseq.dataclass import FairseqDataclass from fairseq.dataclass.configs import GenerationConfig from fairseq.tasks import FairseqTask, register_task -from omegaconf import II, DictConfig +from omegaconf import II from espresso.data import ( AliScpCachedDataset, @@ -339,7 +339,7 @@ def build_dictionary(cls, filenames, workers=1, threshold=-1, nwords=-1, padding """ raise NotImplementedError - def __init__(self, cfg: DictConfig, dictionary): + def __init__(self, cfg: SpeechRecognitionHybridConfig, dictionary): super().__init__(cfg) self.dictionary = dictionary self.feat_in_channels = cfg.feat_in_channels @@ -372,11 +372,11 @@ def __init__(self, cfg: DictConfig, dictionary): torch.backends.cudnn.deterministic = True @classmethod - def setup_task(cls, cfg: DictConfig, **kwargs): + def setup_task(cls, cfg: SpeechRecognitionHybridConfig, **kwargs): """Setup the task (e.g., load dictionaries). 
Args: - cfg (omegaconf.DictConfig): parsed command-line arguments + cfg (SpeechRecognitionHybridConfig): configuration of this task """ # load dictionaries dict_path = cfg.dict From 6c6ee415daaea4cff27dc142beb853c0c05588e9 Mon Sep 17 00:00:00 2001 From: freewym Date: Wed, 11 Nov 2020 16:38:40 -0500 Subject: [PATCH 104/119] code adaptation/changes according to the commits on Nov 11, 2020; obtain feat_dim in setup_task() instead --- espresso/dump_posteriors.py | 7 +-- espresso/speech_recognize.py | 7 +-- espresso/tasks/speech_recognition.py | 52 +++++++++++++++------ espresso/tasks/speech_recognition_hybrid.py | 47 ++++++++++++++----- 4 files changed, 81 insertions(+), 32 deletions(-) diff --git a/espresso/dump_posteriors.py b/espresso/dump_posteriors.py index 214172fe4..62c02cd37 100755 --- a/espresso/dump_posteriors.py +++ b/espresso/dump_posteriors.py @@ -63,15 +63,13 @@ def _main(cfg, output_file): use_cuda = torch.cuda.is_available() and not cfg.common.cpu - # Load dataset split task = tasks.setup_task(cfg.task) - task.load_dataset(cfg.dataset.gen_subset) overrides = ast.literal_eval(cfg.common_eval.model_overrides) # Load ensemble logger.info("loading model(s) from {}".format(cfg.common_eval.path)) - models, _model_args = checkpoint_utils.load_model_ensemble( + models, saved_cfg = checkpoint_utils.load_model_ensemble( utils.split_paths(cfg.common_eval.path), arg_overrides=overrides, task=task, @@ -80,6 +78,9 @@ def _main(cfg, output_file): num_shards=cfg.checkpoint.checkpoint_shard_count, ) + # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config + task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task) + # Load state prior for cross-entropy trained systems decoding if cfg.generation.state_prior_file is not None: prior = torch.from_numpy(kaldi_io.read_vec_flt(cfg.generation.state_prior_file)) diff --git a/espresso/speech_recognize.py b/espresso/speech_recognize.py index 5c982f464..684f47118 100755 --- a/espresso/speech_recognize.py +++ b/espresso/speech_recognize.py @@ -82,9 +82,7 @@ def _main(cfg, output_file): use_cuda = torch.cuda.is_available() and not cfg.common.cpu - # Load dataset split task = tasks.setup_task(cfg.task) - task.load_dataset(cfg.dataset.gen_subset) # Set dictionary dictionary = task.target_dictionary @@ -93,7 +91,7 @@ def _main(cfg, output_file): # Load ensemble logger.info("loading model(s) from {}".format(cfg.common_eval.path)) - models, _model_args = checkpoint_utils.load_model_ensemble( + models, saved_cfg = checkpoint_utils.load_model_ensemble( utils.split_paths(cfg.common_eval.path), arg_overrides=overrides, task=task, @@ -102,6 +100,9 @@ def _main(cfg, output_file): num_shards=cfg.checkpoint.checkpoint_shard_count, ) + # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config + task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task) + if cfg.generation.lm_path is not None: overrides["data"] = cfg.task.data diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index 94cdb9836..31ebadbfd 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -85,6 +85,7 @@ class SpeechRecognitionEspressoConfig(FairseqDataclass): data_buffer_size: int = II("dataset.data_buffer_size") tpu: bool = II("common.tpu") train_subset: str = II("dataset.train_subset") + valid_subset: str = II("dataset.valid_subset") gen_subset: str = II("dataset.gen_subset") 
required_seq_len_multiple: int = II("dataset.required_seq_len_multiple") @@ -94,7 +95,7 @@ def get_asr_dataset_from_json( split, tgt_dict, combine, - upsample_primary, + upsample_primary=1, num_buckets=0, shuffle=True, pad_to_multiple=1, @@ -233,10 +234,11 @@ def build_dictionary( """ raise NotImplementedError - def __init__(self, cfg: SpeechRecognitionEspressoConfig, tgt_dict, word_dict=None): + def __init__(self, cfg: SpeechRecognitionEspressoConfig, tgt_dict, feat_dim, word_dict=None): super().__init__(cfg) self.tgt_dict = tgt_dict self.word_dict = word_dict + self.feat_dim = feat_dim self.feat_in_channels = cfg.feat_in_channels self.specaugment_config = cfg.specaugment_config torch.backends.cudnn.deterministic = True @@ -256,19 +258,48 @@ def setup_task(cls, cfg: SpeechRecognitionEspressoConfig, **kwargs): dict_path = os.path.join(cfg.data, "dict.txt") if cfg.dict is None else cfg.dict tgt_dict = cls.load_dictionary(dict_path, non_lang_syms=cfg.non_lang_syms) logger.info("dictionary: {} types".format(len(tgt_dict))) + + # minimum code for loading data in order to obtain feat_dim + paths = utils.split_paths(cfg.data) + assert len(paths) > 0 + data_path = paths[0] + split = cfg.valid_subset.split(",")[0] # valid set is usually much smaller than train set, so it's faster + try: + src_dataset = get_asr_dataset_from_json(data_path, split, tgt_dict, combine=False).src + except FileNotFoundError: + logger.warning(f"'{split}' set not found. Try to obtain feat_dim from '{cfg.gen_subset}'") + src_dataset = get_asr_dataset_from_json(data_path, cfg.gen_subset, tgt_dict, combined=False).src + if isinstance(src_dataset, ConcatDataset): + feat_dim = src_dataset.datasets[0].feat_dim + elif isinstance(src_dataset, BaseWrapperDataset): + feat_dim = src_dataset.dataset.feat_dim + else: + feat_dim = src_dataset.feat_dim + if cfg.word_dict is not None: word_dict = cls.load_dictionary(cfg.word_dict) logger.info("word dictionary: {} types".format(len(word_dict))) - return cls(cfg, tgt_dict, word_dict) + return cls(cfg, tgt_dict, feat_dim, word_dict=word_dict) else: - return cls(cfg, tgt_dict) - - def load_dataset(self, split, epoch=1, combine=False, **kwargs): + return cls(cfg, tgt_dict, feat_dim) + + def load_dataset( + self, + split: str, + epoch: int = 1, + combine: bool = False, + task_cfg: FairseqDataclass = None, + **kwargs, + ): """Load a given dataset split. 
Args: split (str): name of the split (e.g., train, valid, test) + epoch (int): epoch number determining which shard of training data to load + combine (bool): combines a split segmented into pieces into one dataset + task_cfg (FairseqDataclass): optional task configuration stored in the checkpoint that can be used + to load datasets """ paths = utils.split_paths(self.cfg.data) assert len(paths) > 0 @@ -276,6 +307,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): # if not training data set, use the first shard for valid and test paths = paths[:1] data_path = paths[(epoch - 1) % len(paths)] + task_cfg = task_cfg or self.cfg self.datasets[split] = get_asr_dataset_from_json( data_path, @@ -290,14 +322,6 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): specaugment_config=self.specaugment_config, ) - src_dataset = self.datasets[split].src - if isinstance(src_dataset, ConcatDataset): - self.feat_dim = src_dataset.datasets[0].feat_dim - elif isinstance(src_dataset, BaseWrapperDataset): - self.feat_dim = src_dataset.dataset.feat_dim - else: - self.feat_dim = src_dataset.feat_dim - # update the counts of and in tgt_dict with training data if split == "train": tgt_dataset = self.datasets[split].tgt diff --git a/espresso/tasks/speech_recognition_hybrid.py b/espresso/tasks/speech_recognition_hybrid.py index e2b1b1449..58ffe0c6b 100644 --- a/espresso/tasks/speech_recognition_hybrid.py +++ b/espresso/tasks/speech_recognition_hybrid.py @@ -151,7 +151,7 @@ def get_asr_dataset_from_json( split, dictionary, combine, - upsample_primary, + upsample_primary=1, num_buckets=0, shuffle=True, pad_to_multiple=1, @@ -339,9 +339,10 @@ def build_dictionary(cls, filenames, workers=1, threshold=-1, nwords=-1, padding """ raise NotImplementedError - def __init__(self, cfg: SpeechRecognitionHybridConfig, dictionary): + def __init__(self, cfg: SpeechRecognitionHybridConfig, dictionary, feat_dim): super().__init__(cfg) self.dictionary = dictionary + self.feat_dim = feat_dim self.feat_in_channels = cfg.feat_in_channels self.specaugment_config = cfg.specaugment_config self.num_targets = cfg.num_targets @@ -387,13 +388,42 @@ def setup_task(cls, cfg: SpeechRecognitionHybridConfig, **kwargs): ) if dictionary is not None: logger.info("dictionary: {} types".format(len(dictionary))) - return cls(cfg, dictionary) - def load_dataset(self, split, epoch=1, combine=False, **kwargs): + # minimum code for loading data in order to obtain feat_dim + paths = utils.split_paths(cfg.data) + assert len(paths) > 0 + data_path = paths[0] + split = cfg.valid_subset.split(",")[0] # valid set is usually much smaller than train set, so it's faster + try: + src_dataset = get_asr_dataset_from_json(data_path, split, dictionary, combine=False).src + except FileNotFoundError: + logger.warning(f"'{split}' set not found. Try to obtain feat_dim from '{cfg.gen_subset}'") + src_dataset = get_asr_dataset_from_json(data_path, cfg.gen_subset, dictionary, combined=False).src + if isinstance(src_dataset, ConcatDataset): + feat_dim = src_dataset.datasets[0].feat_dim + elif isinstance(src_dataset, BaseWrapperDataset): + feat_dim = src_dataset.dataset.feat_dim + else: + feat_dim = src_dataset.feat_dim + + return cls(cfg, dictionary, feat_dim) + + def load_dataset( + self, + split: str, + epoch: int = 1, + combine: bool = False, + task_cfg: FairseqDataclass = None, + **kwargs, + ): """Load a given dataset split. 
Args: split (str): name of the split (e.g., train, valid, test) + epoch (int): epoch number determining which shard of training data to load + combine (bool): combines a split segmented into pieces into one dataset + task_cfg (FairseqDataclass): optional task configuration stored in the checkpoint that can be used + to load datasets """ paths = utils.split_paths(self.cfg.data) assert len(paths) > 0 @@ -401,6 +431,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): # if not training data set, use the first shard for valid and test paths = paths[:1] data_path = paths[(epoch - 1) % len(paths)] + task_cfg = task_cfg or self.cfg self.datasets[split] = get_asr_dataset_from_json( data_path, @@ -422,14 +453,6 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): label_delay=self.label_delay, ) - src_dataset = self.datasets[split].src - if isinstance(src_dataset, ConcatDataset): - self.feat_dim = src_dataset.datasets[0].feat_dim - elif isinstance(src_dataset, BaseWrapperDataset): - self.feat_dim = src_dataset.dataset.feat_dim - else: - self.feat_dim = src_dataset.feat_dim - def build_generator(self, models, cfg: GenerationConfig): if cfg.score_reference: cfg.score_reference = False From 2b68caf10740ac83cb6f4d0b7a4fa54410ef8d17 Mon Sep 17 00:00:00 2001 From: freewym Date: Sat, 14 Nov 2020 17:33:49 -0500 Subject: [PATCH 105/119] fix an error when more than one external LMs are used for shallow fusion --- espresso/models/speech_lstm.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/espresso/models/speech_lstm.py b/espresso/models/speech_lstm.py index 4764ffd89..aeca6bf60 100644 --- a/espresso/models/speech_lstm.py +++ b/espresso/models/speech_lstm.py @@ -677,7 +677,10 @@ def extract_features( encoder_padding_mask = torch.empty(0) srclen = encoder_outs.size(0) - if incremental_state is not None and len(incremental_state) > 0: + if ( + incremental_state is not None + and self._get_full_incremental_state_key("cached_state") in incremental_state + ): prev_output_tokens = prev_output_tokens[:, -1:] bsz, seqlen = prev_output_tokens.size() @@ -690,7 +693,10 @@ def extract_features( x = x.transpose(0, 1) # initialize previous states (or get from cache during incremental generation) - if incremental_state is not None and len(incremental_state) > 0: + if ( + incremental_state is not None + and self._get_full_incremental_state_key("cached_state") in incremental_state + ): prev_hiddens, prev_cells, input_feed = self.get_cached_state(incremental_state) else: zero_state = x.new_zeros(bsz, self.hidden_size) @@ -813,7 +819,10 @@ def reorder_incremental_state( incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order: Tensor, ): - if incremental_state is None or len(incremental_state) == 0: + if ( + incremental_state is None + or self._get_full_incremental_state_key("cached_state") not in incremental_state + ): return prev_hiddens, prev_cells, input_feed = self.get_cached_state(incremental_state) prev_hiddens = [p.index_select(0, new_order) for p in prev_hiddens] @@ -832,7 +841,10 @@ def reorder_incremental_state( return def masked_copy_incremental_state(self, incremental_state, another_cached_state, mask): - if incremental_state is None or len(incremental_state) == 0: + if ( + incremental_state is None + or self._get_full_incremental_state_key("cached_state") not in incremental_state + ): assert another_cached_state is None or len(another_cached_state) == 0 return prev_hiddens, prev_cells, input_feed = 
self.get_cached_state(incremental_state) From 72fa5968a3a6c06380a336114c876fc5c4e44caf Mon Sep 17 00:00:00 2001 From: freewym Date: Wed, 18 Nov 2020 16:57:26 -0500 Subject: [PATCH 106/119] code adaptation/changes according to the commits on Nov 16-20, 2020; fix a bug in Multi-level LM when getting cached states --- espresso/criterions/lf_mmi_loss.py | 28 ++--- espresso/models/external_language_model.py | 2 +- espresso/models/speech_lstm.py | 88 ++++++++------- espresso/models/speech_lstm_encoder_model.py | 23 ++-- espresso/models/speech_tdnn.py | 69 ++++++------ espresso/models/speech_transformer.py | 35 +++--- .../speech_transformer_encoder_model.py | 100 +++++++++--------- .../lr_scheduler/reduce_lr_on_plateau_v2.py | 59 ++--------- espresso/tasks/speech_recognition_hybrid.py | 2 +- .../tools/generate_log_probs_for_decoding.py | 9 +- espresso/tools/simple_greedy_decoder.py | 4 +- examples/asr_swbd/run.sh | 2 +- 12 files changed, 205 insertions(+), 216 deletions(-) diff --git a/espresso/criterions/lf_mmi_loss.py b/espresso/criterions/lf_mmi_loss.py index dd00e9188..a6c58990c 100644 --- a/espresso/criterions/lf_mmi_loss.py +++ b/espresso/criterions/lf_mmi_loss.py @@ -6,14 +6,15 @@ from dataclasses import dataclass, field import logging import math +from omegaconf import II import torch from fairseq import utils from fairseq.criterions import FairseqCriterion, register_criterion from fairseq.dataclass import FairseqDataclass +from fairseq.tasks import FairseqTask from fairseq.logging import metrics -from omegaconf import II logger = logging.getLogger(__name__) @@ -131,11 +132,7 @@ def backward(ctx, objf_grad): @register_criterion("lattice_free_mmi", dataclass=LatticeFreeMMICriterionConfig) class LatticeFreeMMICriterion(FairseqCriterion): - - def __init__( - self, task, sentence_avg, denominator_fst_path, leaky_hmm_coefficient, - xent_regularization_coefficient, output_l2_regularization_coefficient, - ): + def __init__(self, cfg: LatticeFreeMMICriterionConfig, task: FairseqTask): super().__init__(task) try: from pychain.graph import ChainGraph @@ -146,12 +143,12 @@ def __init__( "after entering espresso/tools" ) - self.sentence_avg = sentence_avg - den_fst = simplefst.StdVectorFst.read(denominator_fst_path) + self.sentence_avg = cfg.sentence_avg + den_fst = simplefst.StdVectorFst.read(cfg.denominator_fst_path) self.den_graph = ChainGraph(den_fst, initial_mode="leaky", final_mode="ones") - self.leaky_hmm_coefficient = leaky_hmm_coefficient - self.xent_regularize = xent_regularization_coefficient - self.output_l2_regularize = output_l2_regularization_coefficient + self.leaky_hmm_coefficient = cfg.leaky_hmm_coefficient + self.xent_regularize = cfg.xent_regularization_coefficient + self.output_l2_regularize = cfg.output_l2_regularization_coefficient def forward(self, model, sample, reduce=True): """Compute the loss for the given sample. 
@@ -183,8 +180,8 @@ def compute_loss(self, net_output, sample, reduce=True): except ImportError: raise ImportError("Please install OpenFST and PyChain by `make openfst pychain` after entering espresso/tools") - encoder_out = net_output.encoder_out.transpose(0, 1) # T x B x V -> B x T x V - out_lengths = net_output.src_lengths.long() # B + encoder_out = net_output["encoder_out"][0].transpose(0, 1) # T x B x V -> B x T x V + out_lengths = net_output["src_lengths"][0].long() # B den_graphs = ChainGraphBatch(self.den_graph, sample["nsentences"]) if self.xent_regularize > 0.0: den_objf = ChainFunction.apply(encoder_out, out_lengths, den_graphs, self.leaky_hmm_coefficient) @@ -202,7 +199,10 @@ def compute_loss(self, net_output, sample, reduce=True): nll_loss = loss.clone().detach() if self.output_l2_regularize > 0.0: - encoder_padding_mask = net_output.encoder_padding_mask + encoder_padding_mask = ( + net_output["encoder_padding_mask"][0] if len(net_output["encoder_padding_mask"]) > 0 + else None + ) encoder_out_squared = encoder_out.pow(2.0) if encoder_padding_mask is not None: pad_mask = encoder_padding_mask.transpose(0, 1).unsqueeze(-1) # T x B -> B x T x 1 diff --git a/espresso/models/external_language_model.py b/espresso/models/external_language_model.py index c1f04307f..a0ca4af60 100644 --- a/espresso/models/external_language_model.py +++ b/espresso/models/external_language_model.py @@ -375,7 +375,7 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): node.word_idx if node is not None and node.word_idx >= 0 else self.word_unk_idx for node in nodes ]).unsqueeze(-1) # B x 1 - old_wordlm_cached_state = _clone_cached_state(wordlm_cached_state) + old_wordlm_cached_state = _clone_cached_state(self.wordlm_decoder.get_cached_state(incremental_state)) # recompute wordlm_logprobs from inter-word transition probabilities # only for those whose prev_output_token is diff --git a/espresso/models/speech_lstm.py b/espresso/models/speech_lstm.py index aeca6bf60..0936aee36 100644 --- a/espresso/models/speech_lstm.py +++ b/espresso/models/speech_lstm.py @@ -20,7 +20,6 @@ register_model, register_model_architecture, ) -from fairseq.models.fairseq_encoder import EncoderOut from fairseq.models.lstm import ( Embedding, LSTM, @@ -462,37 +461,44 @@ def forward( encoder_padding_mask = padding_mask.t() - return EncoderOut( - encoder_out=x, # T x B x C - encoder_padding_mask=encoder_padding_mask if encoder_padding_mask.any() - else None, # T x B - encoder_embedding=None, - encoder_states=None, - src_tokens=None, - src_lengths=src_lengths, # B - ) + # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in + # `foward` so we use a dictionary instead. + # TorchScript does not support mixed values so the values are all lists. + # The empty list is equivalent to None. 
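# An illustrative aside (not part of the model code in this patch): callers are
# expected to unwrap the one-element lists themselves and to treat an empty list
# as None, e.g. with a hypothetical local variable `enc` holding this dictionary:
#     out = enc["encoder_out"][0] if len(enc["encoder_out"]) > 0 else None
#     mask = (
#         enc["encoder_padding_mask"][0]
#         if len(enc["encoder_padding_mask"]) > 0
#         else None
#     )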
+ return { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [encoder_padding_mask] if encoder_padding_mask.any() + else [], # T x B + "encoder_embedding": [], + "encoder_states": [], + "src_tokens": [], + "src_lengths": [src_lengths], # B + } + + def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order): + if len(encoder_out["encoder_out"]) == 0: + new_encoder_out = [] + else: + new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)] + if len(encoder_out["encoder_padding_mask"]) == 0: + new_encoder_padding_mask = [] + else: + new_encoder_padding_mask = [ + encoder_out["encoder_padding_mask"][0].index_select(1, new_order) # note: transposed + ] + if len(encoder_out["src_lengths"]) == 0: + new_src_lengths = [] + else: + new_src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)] - def reorder_encoder_out(self, encoder_out: EncoderOut, new_order): - encoder_padding_mask: Optional[Tensor] = encoder_out.encoder_padding_mask - src_lengths: Optional[Tensor] = encoder_out.src_lengths - new_encoder_padding_mask = ( - encoder_padding_mask - if encoder_padding_mask is None - else encoder_padding_mask.index_select(1, new_order) - ) - new_src_lengths = ( - src_lengths - if src_lengths is None - else src_lengths.index_select(0, new_order) - ) - return EncoderOut( - encoder_out=encoder_out.encoder_out.index_select(1, new_order), - encoder_padding_mask=new_encoder_padding_mask, - encoder_embedding=None, - encoder_states=None, - src_tokens=None, - src_lengths=new_src_lengths, - ) + return { + "encoder_out": new_encoder_out, # T x B x C + "encoder_padding_mask": new_encoder_padding_mask, # B x T + "encoder_embedding": [], + "encoder_states": [], + "src_tokens": [], + "src_lengths": new_src_lengths, # B x 1 + } def max_positions(self): """Maximum input length supported by the encoder.""" @@ -592,7 +598,7 @@ def __init__( def forward( self, prev_output_tokens, - encoder_out: Optional[EncoderOut] = None, + encoder_out: Optional[Dict[str, List[Tensor]]] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, **kwargs, ): @@ -600,7 +606,7 @@ def forward( Args: prev_output_tokens (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for input feeding/teacher forcing - encoder_out (EncoderOut, optional): output from the encoder, used for + encoder_out (optional): output from the encoder, used for encoder-side attention incremental_state (dict): dictionary used for storing state during :ref:`Incremental decoding` @@ -628,7 +634,7 @@ def _forward_with_scheduled_sampling( self, prev_output_tokens, sampling_prob, - encoder_out: Optional[EncoderOut] = None, + encoder_out: Optional[Dict[str, List[Tensor]]] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, ): bsz, seqlen = prev_output_tokens.size() @@ -655,7 +661,7 @@ def _forward_with_scheduled_sampling( def extract_features( self, prev_output_tokens, - encoder_out: Optional[EncoderOut] = None, + encoder_out: Optional[Dict[str, List[Tensor]]] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, **unused, ): @@ -670,8 +676,16 @@ def extract_features( # get outputs from encoder if encoder_out is not None: assert self.attention is not None - encoder_outs = encoder_out.encoder_out - encoder_padding_mask = encoder_out.encoder_padding_mask + encoder_outs = ( + encoder_out["encoder_out"][0] + if len(encoder_out["encoder_out"]) > 0 + else torch.empty(0) + ) + encoder_padding_mask = ( + 
encoder_out["encoder_padding_mask"][0] + if len(encoder_out["encoder_padding_mask"]) > 0 + else None + ) else: encoder_outs = torch.empty(0) encoder_padding_mask = torch.empty(0) diff --git a/espresso/models/speech_lstm_encoder_model.py b/espresso/models/speech_lstm_encoder_model.py index 565baa2cd..c024d1ef6 100644 --- a/espresso/models/speech_lstm_encoder_model.py +++ b/espresso/models/speech_lstm_encoder_model.py @@ -17,7 +17,6 @@ register_model, register_model_architecture, ) -from fairseq.models.fairseq_encoder import EncoderOut from fairseq.models.lstm import Linear from omegaconf import DictConfig @@ -254,15 +253,19 @@ def forward( if self.fc_out is not None: x = self.fc_out(x) # T x B x C -> T x B x V - return EncoderOut( - encoder_out=x, # T x B x C - encoder_padding_mask=encoder_padding_mask if encoder_padding_mask.any() - else None, # T x B - encoder_embedding=None, - encoder_states=None, - src_tokens=None, - src_lengths=x_lengths, # B - ) + # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in + # `foward` so we use a dictionary instead. + # TorchScript does not support mixed values so the values are all lists. + # The empty list is equivalent to None. + return { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [encoder_padding_mask] if encoder_padding_mask.any() + else [], # T x B + "encoder_embedding": [], + "encoder_states": [], + "src_tokens": [], + "src_lengths": [x_lengths], # B + } @register_model_architecture("speech_lstm_encoder_model", "speech_lstm_encoder_model") diff --git a/espresso/models/speech_tdnn.py b/espresso/models/speech_tdnn.py index 214a4d0a7..7443a3b65 100644 --- a/espresso/models/speech_tdnn.py +++ b/espresso/models/speech_tdnn.py @@ -5,7 +5,7 @@ from argparse import Namespace import logging -from typing import Optional +from typing import Dict, List, Optional import torch from torch import Tensor @@ -19,7 +19,6 @@ register_model, register_model_architecture, ) -from fairseq.models.fairseq_encoder import EncoderOut from fairseq.models.lstm import Linear from fairseq.modules import FairseqDropout from omegaconf import DictConfig @@ -259,14 +258,19 @@ def forward(self, src_tokens, src_lengths: Tensor, **unused): assert not encoder_padding_mask.any() x = self.output_layer(x) - return EncoderOut( - encoder_out=x, # T x B x C - encoder_padding_mask=encoder_padding_mask if encoder_padding_mask.any() else None, # T x B - encoder_embedding=None, - encoder_states=None, - src_tokens=None, - src_lengths=x_lengths, # B - ) + # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in + # `foward` so we use a dictionary instead. + # TorchScript does not support mixed values so the values are all lists. + # The empty list is equivalent to None. 
+ return { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [encoder_padding_mask] if encoder_padding_mask.any() + else [], # T x B + "encoder_embedding": [], + "encoder_states": [], + "src_tokens": [], + "src_lengths": [x_lengths], # B + } def extract_features(self, src_tokens, src_lengths, **unused): x, x_lengths = src_tokens, src_lengths @@ -289,27 +293,30 @@ def output_layer(self, features, **kwargs): """Project features to the vocabulary size.""" return self.fc_out(features) # T x B x C -> T x B x V - def reorder_encoder_out(self, encoder_out: EncoderOut, new_order): - encoder_padding_mask: Optional[Tensor] = encoder_out.encoder_padding_mask - src_lengths: Optional[Tensor] = encoder_out.src_lengths - new_encoder_padding_mask = ( - encoder_padding_mask - if encoder_padding_mask is None - else encoder_padding_mask.index_select(1, new_order) - ) - new_src_lengths = ( - src_lengths - if src_lengths is None - else src_lengths.index_select(0, new_order) - ) - return EncoderOut( - encoder_out=encoder_out.encoder_out.index_select(1, new_order), - encoder_padding_mask=new_encoder_padding_mask, - encoder_embedding=None, - encoder_states=None, - src_tokens=None, - src_lengths=new_src_lengths, - ) + def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order): + if len(encoder_out["encoder_out"]) == 0: + new_encoder_out = [] + else: + new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)] + if len(encoder_out["encoder_padding_mask"]) == 0: + new_encoder_padding_mask = [] + else: + new_encoder_padding_mask = [ + encoder_out["encoder_padding_mask"][0].index_select(1, new_order) # note: transposed + ] + if len(encoder_out["src_lengths"]) == 0: + new_src_lengths = [] + else: + new_src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)] + + return { + "encoder_out": new_encoder_out, # T x B x C + "encoder_padding_mask": new_encoder_padding_mask, # B x T + "encoder_embedding": [], + "encoder_states": [], + "src_tokens": [], + "src_lengths": new_src_lengths, # B x 1 + } def max_positions(self): """Maximum input length supported by the encoder.""" diff --git a/espresso/models/speech_transformer.py b/espresso/models/speech_transformer.py index a381b1a77..f7a9ac6e4 100644 --- a/espresso/models/speech_transformer.py +++ b/espresso/models/speech_transformer.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. 
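The reorder_encoder_out() implementations above (and the transformer ones later in this patch) all follow the same contract: during beam search the generator reorders the cached encoder output along the batch dimension whenever hypotheses are selected or beams are expanded. A minimal self-contained sketch of that contract, with made-up tensor sizes and only the two keys that matter here (illustrative, not taken from the patch):

    import torch

    # shaped like the encoders above return: T=50 frames, B=2 sentences, C=256 channels
    enc = {
        "encoder_out": [torch.randn(50, 2, 256)],
        "encoder_padding_mask": [],              # empty list plays the role of None
        "src_lengths": [torch.tensor([50, 47])],
    }
    # beam search repeats each sentence beam_size times; for beam_size=2 the new
    # batch order is [0, 0, 1, 1]
    new_order = torch.tensor([0, 0, 1, 1])
    # reorder_encoder_out() boils down to an index_select along the batch axis of each entry
    reordered_out = enc["encoder_out"][0].index_select(1, new_order)  # dim 1 is the batch dim of T x B x C
    reordered_len = enc["src_lengths"][0].index_select(0, new_order)  # dim 0 is the batch dim here
    assert reordered_out.shape == (50, 4, 256)
    assert reordered_len.tolist() == [50, 50, 47, 47]
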
import logging -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import torch from torch import Tensor @@ -15,7 +15,6 @@ register_model, register_model_architecture, ) -from fairseq.models.fairseq_encoder import EncoderOut from fairseq.models.transformer import ( Linear, TransformerModel, @@ -384,15 +383,12 @@ def forward( if self.layernorm_embedding is not None: x = self.layernorm_embedding(x) - if not encoder_padding_mask.any(): - encoder_padding_mask = None - # B x T x C -> T x B x C x = x.transpose(0, 1) attn_mask = self.get_attn_mask(src_lengths) - encoder_states = [] if return_all_hiddens else None + encoder_states = [] # encoder layers for layer in self.layers: @@ -404,14 +400,19 @@ def forward( if self.layer_norm is not None: x = self.layer_norm(x) - return EncoderOut( - encoder_out=x, # T x B x C - encoder_padding_mask=encoder_padding_mask, # B x T - encoder_embedding=None, - encoder_states=encoder_states, # List[T x B x C] - src_tokens=None, - src_lengths=None, - ) + # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in + # `foward` so we use a dictionary instead. + # TorchScript does not support mixed values so the values are all lists. + # The empty list is equivalent to None. + return { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [encoder_padding_mask] if encoder_padding_mask.any() + else [], # B x T + "encoder_embedding": [], + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": [], + "src_lengths": [], + } def max_positions(self): """Maximum input length supported by the encoder.""" @@ -433,7 +434,7 @@ def __init__( def forward( self, prev_output_tokens, - encoder_out: Optional[EncoderOut] = None, + encoder_out: Optional[Dict[str, List[Tensor]]] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, features_only: bool = False, full_context_alignment: bool = False, @@ -447,7 +448,7 @@ def forward( Args: prev_output_tokens (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for input feeding/teacher forcing - encoder_out (EncoderOut, optional): output from the encoder, used for + encoder_out (optional): output from the encoder, used for encoder-side attention incremental_state (dict): dictionary used for storing state during :ref:`Incremental decoding` @@ -494,7 +495,7 @@ def _forward_with_scheduled_sampling( self, prev_output_tokens, sampling_prob, - encoder_out: Optional[EncoderOut] = None, + encoder_out: Optional[Dict[str, List[Tensor]]] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, features_only: bool = False, full_context_alignment: bool = False, diff --git a/espresso/models/speech_transformer_encoder_model.py b/espresso/models/speech_transformer_encoder_model.py index 9fccf6813..1e78f86ca 100644 --- a/espresso/models/speech_transformer_encoder_model.py +++ b/espresso/models/speech_transformer_encoder_model.py @@ -5,7 +5,7 @@ from argparse import Namespace import logging -from typing import Optional +from typing import Dict, List, Optional import torch from torch import Tensor @@ -17,7 +17,6 @@ register_model, register_model_architecture, ) -from fairseq.models.fairseq_encoder import EncoderOut from fairseq.models.transformer import Linear from omegaconf import DictConfig @@ -278,7 +277,7 @@ def forward( Only populated if *return_all_hiddens* is True. 
""" out = super().forward(src_tokens, src_lengths, return_all_hiddens=return_all_hiddens) - x, x_lengths = out.encoder_out, out.src_lengths + x, x_lengths = out["encoder_out"][0], out["src_lengths"][0] # determine which output frame to select for loss evaluation/test, assuming # all examples in a batch are of the same length for chunk-wise training/test @@ -292,17 +291,21 @@ def forward( if self.fc_out is not None: x = self.fc_out(x) # T x B x C -> T x B x V - return EncoderOut( - encoder_out=x, # T x B x C - encoder_padding_mask=out.encoder_padding_mask.transpose(0, 1), # T x B - encoder_embedding=out.encoder_embedding, # None - encoder_states=out.encoder_states, # List[T x B x C] - src_tokens=out.src_tokens, # None - src_lengths=x_lengths, # B - ) + # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in + # `foward` so we use a dictionary instead. + # TorchScript does not support mixed values so the values are all lists. + # The empty list is equivalent to None. + return { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [out["encoder_padding_mask"][0].transpose(0, 1)], # T x B + "encoder_embedding": out["encoder_embedding"], # None + "encoder_states": out["encoder_states"], # List[T x B x C] + "src_tokens": out["src_tokens"], # None + "src_lengths": [x_lengths], # B + } @torch.jit.export - def reorder_encoder_out(self, encoder_out: EncoderOut, new_order): + def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order): """ Reorder encoder output according to *new_order*. @@ -313,50 +316,45 @@ def reorder_encoder_out(self, encoder_out: EncoderOut, new_order): Returns: *encoder_out* rearranged according to *new_order* """ - """ - Since encoder_padding_mask and encoder_embedding are both of type - Optional[Tensor] in EncoderOut, they need to be copied as local - variables for Torchscript Optional refinement - """ - encoder_padding_mask: Optional[Tensor] = encoder_out.encoder_padding_mask - encoder_embedding: Optional[Tensor] = encoder_out.encoder_embedding - - new_encoder_out = ( - encoder_out.encoder_out - if encoder_out.encoder_out is None - else encoder_out.encoder_out.index_select(1, new_order) - ) - new_encoder_padding_mask = ( - encoder_padding_mask - if encoder_padding_mask is None - else encoder_padding_mask.index_select(1, new_order) # note: transposed - ) - new_encoder_embedding = ( - encoder_embedding - if encoder_embedding is None - else encoder_embedding.index_select(0, new_order) - ) - src_tokens = encoder_out.src_tokens - if src_tokens is not None: - src_tokens = src_tokens.index_select(0, new_order) + if len(encoder_out["encoder_out"]) == 0: + new_encoder_out = [] + else: + new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)] + if len(encoder_out["encoder_padding_mask"]) == 0: + new_encoder_padding_mask = [] + else: + new_encoder_padding_mask = [ + encoder_out["encoder_padding_mask"][0].index_select(1, new_order) # note: transposed + ] + if len(encoder_out["encoder_embedding"]) == 0: + new_encoder_embedding = [] + else: + new_encoder_embedding = [ + encoder_out["encoder_embedding"][0].index_select(0, new_order) + ] + if len(encoder_out["src_tokens"]) == 0: + new_src_tokens = [] + else: + new_src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)] - src_lengths = encoder_out.src_lengths - if src_lengths is not None: - src_lengths = src_lengths.index_select(0, new_order) + if len(encoder_out["src_lengths"]) == 0: + new_src_lengths = [] + else: + new_src_lengths = 
[(encoder_out["src_lengths"][0]).index_select(0, new_order)] - encoder_states = encoder_out.encoder_states - if encoder_states is not None: + encoder_states = encoder_out["encoder_states"] + if len(encoder_states) > 0: for idx, state in enumerate(encoder_states): encoder_states[idx] = state.index_select(1, new_order) - return EncoderOut( - encoder_out=new_encoder_out, # T x B x C - encoder_padding_mask=new_encoder_padding_mask, # B x T - encoder_embedding=new_encoder_embedding, # B x T x C - encoder_states=encoder_states, # List[T x B x C] - src_tokens=src_tokens, # B x T - src_lengths=src_lengths, # B x 1 - ) + return { + "encoder_out": new_encoder_out, # T x B x C + "encoder_padding_mask": new_encoder_padding_mask, # B x T + "encoder_embedding": new_encoder_embedding, # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": new_src_tokens, # B x T + "src_lengths": new_src_lengths, # B x 1 + } @register_model_architecture("speech_transformer_encoder_model", "speech_transformer_encoder_model") diff --git a/espresso/optim/lr_scheduler/reduce_lr_on_plateau_v2.py b/espresso/optim/lr_scheduler/reduce_lr_on_plateau_v2.py index c3f1141db..0f60eb284 100644 --- a/espresso/optim/lr_scheduler/reduce_lr_on_plateau_v2.py +++ b/espresso/optim/lr_scheduler/reduce_lr_on_plateau_v2.py @@ -8,56 +8,27 @@ import torch.optim.lr_scheduler -from fairseq.dataclass import FairseqDataclass -from fairseq.dataclass.utils import gen_parser_from_dataclass from fairseq.optim.lr_scheduler import register_lr_scheduler -from fairseq.optim.lr_scheduler.reduce_lr_on_plateau import ReduceLROnPlateau -from omegaconf import II, DictConfig +from fairseq.optim.lr_scheduler.reduce_lr_on_plateau import ( + ReduceLROnPlateauLRSchedule, + ReduceLROnPlateauLRScheduleConfig, +) @dataclass -class ReduceLROnPlateauV2Config(FairseqDataclass): - lr_shrink: float = field( - default=0.1, - metadata={"help": "shrink factor for annealing, lr_new = (lr * lr_shrink)"}, - ) - lr_threshold: float = field( - default=1e-4, - metadata={ - "help": "threshold for measuring the new optimum, to only focus on significant changes" - }, - ) - lr_patience: int = field( - default=0, - metadata={ - "help": "number of epochs with no improvement after which learning rate will be reduced" - }, - ) - warmup_updates: int = field( +class ReduceLROnPlateauLRScheduleV2Config(ReduceLROnPlateauLRScheduleConfig): + start_reduce_lr_epoch: int = field( default=0, - metadata={"help": "warmup the learning rate linearly for the first N updates"}, - ) - warmup_init_lr: float = field( - default=-1, - metadata={ - "help": "initial learning rate during warmup phase; default is cfg.lr" - }, + metadata={"help": "start to reduce lr from the specified epoch"}, ) final_lr_scale: float = field( default=0.01, metadata={"help": "final learning rate scale; default to 0.01"}, ) - start_reduce_lr_epoch: int = field( - default=0, - metadata={"help": "start to reduce lr from the specified epoch"}, - ) - # TODO common vars at parent class - lr: List[float] = II("optimization.lr") - maximize_best_checkpoint_metric: bool = II("checkpoint.maximize_best_checkpoint_metric") -@register_lr_scheduler("reduce_lr_on_plateau_v2", dataclass=ReduceLROnPlateauV2Config) -class ReduceLROnPlateauV2(ReduceLROnPlateau): +@register_lr_scheduler("reduce_lr_on_plateau_v2", dataclass=ReduceLROnPlateauLRScheduleV2Config) +class ReduceLROnPlateauLRScheduleV2(ReduceLROnPlateauLRSchedule): """Decay the LR by a factor every time the validation loss plateaus, starting from the epoch specified as 
cfg.start_reduce_lr_epoch. @@ -65,10 +36,8 @@ class ReduceLROnPlateauV2(ReduceLROnPlateau): of epochs is reached. """ - def __init__(self, cfg: DictConfig, fairseq_optimizer): - super().__init__(cfg, fairseq_optimizer) - - self.cfg = cfg + def __init__(self, cfg: ReduceLROnPlateauLRScheduleV2Config, optimizer): + super().__init__(cfg, optimizer) self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( self.optimizer.optimizer, patience=cfg.lr_patience, @@ -78,12 +47,6 @@ def __init__(self, cfg: DictConfig, fairseq_optimizer): min_lr=cfg.final_lr_scale * cfg.lr[0], ) - @classmethod - def add_args(cls, parser): - dc = getattr(cls, "__dataclass", None) - if dc is not None: - gen_parser_from_dataclass(parser, dc()) - def step(self, epoch, val_loss=None): if epoch < self.cfg.start_reduce_lr_epoch: self.lr_scheduler.last_epoch = epoch diff --git a/espresso/tasks/speech_recognition_hybrid.py b/espresso/tasks/speech_recognition_hybrid.py index 58ffe0c6b..26e470ef0 100644 --- a/espresso/tasks/speech_recognition_hybrid.py +++ b/espresso/tasks/speech_recognition_hybrid.py @@ -415,7 +415,7 @@ def load_dataset( combine: bool = False, task_cfg: FairseqDataclass = None, **kwargs, - ): + ): """Load a given dataset split. Args: diff --git a/espresso/tools/generate_log_probs_for_decoding.py b/espresso/tools/generate_log_probs_for_decoding.py index c9eedebd0..a38112a3b 100644 --- a/espresso/tools/generate_log_probs_for_decoding.py +++ b/espresso/tools/generate_log_probs_for_decoding.py @@ -53,10 +53,13 @@ def _generate(self, sample: Dict[str, Dict[str, Tensor]], **kwargs): # compute the encoder output encoder_outs = self.model.forward_encoder(net_input) - logits = encoder_outs[0].encoder_out.transpose(0, 1).float() # T x B x V -> B x T x V + logits = encoder_outs[0]["encoder_out"][0].transpose(0, 1).float() # T x B x V -> B x T x V assert logits.size(0) == bsz - padding_mask = encoder_outs[0].encoder_padding_mask.t() \ - if encoder_outs[0].encoder_padding_mask is not None else None + padding_mask = ( + encoder_outs[0]["encoder_padding_mask"][0].t() + if len(encoder_outs[0]["encoder_padding_mask"]) > 0 + else None + ) if self.apply_log_softmax: return F.log_softmax(logits, dim=-1), padding_mask return logits, padding_mask diff --git a/espresso/tools/simple_greedy_decoder.py b/espresso/tools/simple_greedy_decoder.py index a34bee5f6..e1ef37eec 100644 --- a/espresso/tools/simple_greedy_decoder.py +++ b/espresso/tools/simple_greedy_decoder.py @@ -90,7 +90,7 @@ def _decode(self, sample: Dict[str, Dict[str, Tensor]], bos_token: Optional[int] target = sample["target"] # target can only be None if not for validation assert target is not None or not self.for_validation - max_encoder_output_length = encoder_outs[0].encoder_out.size(0) + max_encoder_output_length = encoder_outs[0]["encoder_out"][0].size(0) # for validation, make the maximum decoding length equal to at least the # length of target, and the length of encoder_out if possible; otherwise # max_len is obtained from max_len_a/b @@ -105,7 +105,7 @@ def _decode(self, sample: Dict[str, Dict[str, Tensor]], bos_token: Optional[int] tokens = src_tokens.new(bsz, max_len + 2).long().fill_(self.pad) tokens[:, 0] = self.eos if bos_token is None else bos_token # lprobs is only used when target is not None (i.e., for validation) - lprobs = encoder_outs[0].encoder_out.new_full( + lprobs = encoder_outs[0]["encoder_out"][0].new_full( (bsz, target.size(1), self.vocab_size), -np.log(self.vocab_size), ) if self.for_validation else None attn = None diff --git 
a/examples/asr_swbd/run.sh b/examples/asr_swbd/run.sh index f3be25a98..e82225dba 100755 --- a/examples/asr_swbd/run.sh +++ b/examples/asr_swbd/run.sh @@ -34,7 +34,7 @@ fisher_dirs= if [[ $(hostname -f) == *.clsp.jhu.edu ]]; then swbd1_dir=/export/corpora3/LDC/LDC97S62 - eval2000_dir="/export/corpora2/LDC/LDC2002S09/hub5e_00 /export/corpora2/LDC/LDC2002T43" + eval2000_dir="/export/corpora3/LDC/LDC2002S09/hub5e_00 /export/corpora3/LDC/LDC2002T43" rt03_dir=/export/corpora/LDC/LDC2007S10 fisher_dirs="/export/corpora3/LDC/LDC2004T19/fe_03_p1_tran/ /export/corpora3/LDC/LDC2005T19/fe_03_p2_tran/" fi From 6b4e5717b382ebe474f4c637dff23dcad324ee9f Mon Sep 17 00:00:00 2001 From: freewym Date: Fri, 4 Dec 2020 15:52:07 -0500 Subject: [PATCH 107/119] fix length tensor device issue in lf_mmi loss; code adaptation/changes according to the commits on Dec 3-12, 2020 --- espresso/criterions/lf_mmi_loss.py | 4 +-- espresso/models/speech_transformer.py | 2 +- .../speech_transformer_encoder_model.py | 2 +- espresso/speech_recognize.py | 13 ++++---- espresso/speech_train.py | 31 ++++++++++++++++--- examples/asr_wsj/run.sh | 2 +- 6 files changed, 37 insertions(+), 17 deletions(-) diff --git a/espresso/criterions/lf_mmi_loss.py b/espresso/criterions/lf_mmi_loss.py index a6c58990c..9a3de34fa 100644 --- a/espresso/criterions/lf_mmi_loss.py +++ b/espresso/criterions/lf_mmi_loss.py @@ -51,7 +51,7 @@ def forward(ctx, input, input_lengths, num_graphs, den_graphs, leaky_coefficient "after entering espresso/tools" ) - input = input.clamp(-30, 30) # clamp for both the denominator and the numerator + input = input.contiguous().clamp(-30, 30) # clamp for both the denominator and the numerator B = input.size(0) if B != num_graphs.batch_size or B != den_graphs.batch_size: raise ValueError( @@ -60,7 +60,7 @@ def forward(ctx, input, input_lengths, num_graphs, den_graphs, leaky_coefficient .format(B, num_graphs.batch_size, den_graphs.batch_size) ) packed_data = torch.nn.utils.rnn.pack_padded_sequence( - input, input_lengths, batch_first=True, + input, input_lengths.cpu(), batch_first=True, ) batch_sizes = packed_data.batch_sizes input_lengths = input_lengths.cpu() diff --git a/espresso/models/speech_transformer.py b/espresso/models/speech_transformer.py index f7a9ac6e4..9ebe36c07 100644 --- a/espresso/models/speech_transformer.py +++ b/espresso/models/speech_transformer.py @@ -349,7 +349,7 @@ def forward( intermediate hidden states (default: False). Returns: - namedtuple: + dict: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of diff --git a/espresso/models/speech_transformer_encoder_model.py b/espresso/models/speech_transformer_encoder_model.py index 1e78f86ca..ec568d5c5 100644 --- a/espresso/models/speech_transformer_encoder_model.py +++ b/espresso/models/speech_transformer_encoder_model.py @@ -265,7 +265,7 @@ def forward( intermediate hidden states (default: False). 
Returns: - namedtuple: + dict: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of diff --git a/espresso/speech_recognize.py b/espresso/speech_recognize.py index 684f47118..26a2eb6f8 100755 --- a/espresso/speech_recognize.py +++ b/espresso/speech_recognize.py @@ -194,12 +194,11 @@ def _main(cfg, output_file): "eos_factor": cfg.generation.eos_factor, } cfg.generation.score_reference = False # not applicable for ASR - temp_val = cfg.generation.print_alignment - cfg.generation.print_alignment = False # not applicable for ASR + save_attention_plot = cfg.generation.print_alignment is not None + cfg.generation.print_alignment = None # not applicable for ASR generator = task.build_generator( models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs ) - cfg.generation.print_alignment = temp_val # Handle tokenization and BPE tokenizer = task.build_tokenizer(cfg.tokenizer) @@ -242,7 +241,7 @@ def decode_fn(x): gen_timer.stop(num_generated_tokens) # obtain nonpad mask of encoder output to plot attentions - if cfg.generation.print_alignment: + if save_attention_plot: net_input = sample["net_input"] src_tokens = net_input["src_tokens"] output_lengths = models[0].encoder.output_lengths(net_input["src_lengths"]) @@ -275,8 +274,8 @@ def decode_fn(x): if j == 0: # src_len x tgt_len attention = hypo["attention"][nonpad_idxs[i]].float().cpu() \ - if cfg.generation.print_alignment and hypo["attention"] is not None else None - if cfg.generation.print_alignment and attention is not None: + if save_attention_plot and hypo["attention"] is not None else None + if save_attention_plot and attention is not None: save_dir = os.path.join(cfg.common_eval.results_path, "attn_plots") os.makedirs(save_dir, exist_ok=True) plot_attention(attention, detok_hypo_str, utt_id, save_dir) @@ -291,7 +290,7 @@ def decode_fn(x): logger.info("NOTE: hypothesis and token scores are output in base 2") logger.info("Recognized {} utterances ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)".format( num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg)) - if cfg.generation.print_alignment: + if save_attention_plot: logger.info("Saved attention plots in " + save_dir) if has_target: diff --git a/espresso/speech_train.py b/espresso/speech_train.py index ffe14b661..35559d083 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -80,7 +80,7 @@ def main(cfg: DictConfig) -> None: logger.info(model) logger.info("task: {}".format(task.__class__.__name__)) logger.info("model: {}".format(model.__class__.__name__)) - logger.info("criterion: {})".format(criterion.__class__.__name__)) + logger.info("criterion: {}".format(criterion.__class__.__name__)) logger.info( "num. model params: {} (num. 
trained: {})".format( sum(p.numel() for p in model.parameters()), @@ -129,7 +129,15 @@ def main(cfg: DictConfig) -> None: lr = trainer.get_lr() train_meter = meters.StopwatchMeter() train_meter.start() - while lr > cfg.optimization.min_lr and epoch_itr.next_epoch_idx <= max_epoch: + while epoch_itr.next_epoch_idx <= max_epoch: + if lr <= cfg.optimization.stop_min_lr: + logger.info( + f"stopping training because current learning rate ({lr}) is smaller " + "than or equal to minimum learning rate " + f"(--stop-min-lr={cfg.optimization.stop_min_lr})" + ) + break + # train for one epoch valid_losses, should_stop = train(cfg, trainer, task, epoch_itr) if should_stop: @@ -193,7 +201,7 @@ def train( else cfg.optimization.update_freq[-1] ) itr = iterators.GroupedIterator(itr, update_freq) - if getattr(cfg.common, "tpu", False): + if cfg.common.tpu: itr = utils.tpu_data_loader(itr) progress = progress_bar.progress_bar( itr, @@ -207,7 +215,15 @@ def train( ), default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), wandb_project=( - cfg.common.wandb_project if distributed_utils.is_master(cfg.distributed_training) else None + cfg.common.wandb_project + if distributed_utils.is_master(cfg.distributed_training) + else None + ), + wandb_run_name=os.environ.get( + "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir) + ), + azureml_logging=( + cfg.common.azureml_logging if distributed_utils.is_master(cfg.distributed_training) else False ), ) @@ -355,7 +371,12 @@ def validate( ), default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), wandb_project=( - cfg.common.wandb_project if distributed_utils.is_master(cfg.distributed_training) else None + cfg.common.wandb_project + if distributed_utils.is_master(cfg.distributed_training) + else None + ), + wandb_run_name=os.environ.get( + "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir) ), ) diff --git a/examples/asr_wsj/run.sh b/examples/asr_wsj/run.sh index c0bba6d61..15df5ef30 100755 --- a/examples/asr_wsj/run.sh +++ b/examples/asr_wsj/run.sh @@ -320,7 +320,7 @@ if [ ${stage} -le 10 ]; then --num-shards 1 --shard-id 0 --dict $dict --bpe characters_asr --non-lang-syms $nlsyms \ --gen-subset $dataset --max-source-positions 9999 --max-target-positions 999 \ --path $path --beam 50 --max-len-a 0.2 --max-len-b 0 --lenpen 1.0 \ - --results-path $decode_dir $opts --print-alignment + --results-path $decode_dir $opts --print-alignment hard echo "log saved in ${decode_dir}/decode.log" if $kaldi_scoring; then From 1f812eb3a9f7e99a309bc83002fd65496a2ffe6f Mon Sep 17 00:00:00 2001 From: freewym Date: Wed, 23 Dec 2020 03:53:50 -0500 Subject: [PATCH 108/119] code adaptation/changes according to the commits on Dec 22, 2020 --- espresso/speech_train.py | 46 ++++++++++++++++++++++++++-------------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/espresso/speech_train.py b/espresso/speech_train.py index 35559d083..8bd45cd10 100755 --- a/espresso/speech_train.py +++ b/espresso/speech_train.py @@ -17,7 +17,6 @@ import numpy as np import torch - from fairseq import ( checkpoint_utils, distributed_utils, @@ -30,8 +29,8 @@ from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import meters, metrics, progress_bar from fairseq.model_parallel.megatron_trainer import MegatronTrainer -from omegaconf import DictConfig from fairseq.trainer import Trainer +from omegaconf import DictConfig logging.basicConfig( @@ -223,7 +222,9 @@ def train( "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir) ), 
azureml_logging=( - cfg.common.azureml_logging if distributed_utils.is_master(cfg.distributed_training) else False + cfg.common.azureml_logging + if distributed_utils.is_master(cfg.distributed_training) + else False ), ) @@ -284,9 +285,32 @@ def validate_and_save( ) -> Tuple[List[Optional[float]], bool]: num_updates = trainer.get_num_updates() max_update = cfg.optimization.max_update or math.inf + + # Stopping conditions (and an additional one based on validation loss later + # on) + should_stop = False + if num_updates >= max_update: + should_stop = True + logger.info( + f"Stopping training due to " + f"num_updates: {num_updates} >= max_update: {max_update}" + ) + + training_time_hours = trainer.cumulative_training_time() / (60 * 60) + if ( + cfg.optimization.stop_time_hours > 0 + and training_time_hours > cfg.optimization.stop_time_hours + ): + should_stop = True + logger.info( + f"Stopping training due to " + f"cumulative_training_time: {training_time_hours} > " + f"stop_time_hours: {cfg.optimization.stop_time_hours} hour(s)" + ) + do_save = ( (end_of_epoch and epoch_itr.epoch % cfg.checkpoint.save_interval == 0) - or num_updates >= max_update + or should_stop or ( cfg.checkpoint.save_interval_updates > 0 and num_updates > 0 @@ -297,7 +321,7 @@ def validate_and_save( do_validate = ( (not end_of_epoch and do_save) # validate during mid-epoch saves or (end_of_epoch and epoch_itr.epoch % cfg.dataset.validate_interval == 0) - or num_updates >= max_update + or should_stop or ( cfg.dataset.validate_interval_updates > 0 and num_updates > 0 @@ -310,20 +334,10 @@ def validate_and_save( if do_validate: valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets) - # Stopping conditions - should_stop = ( - should_stop_early(cfg, valid_losses[0]) - or num_updates >= max_update - or ( - cfg.optimization.stop_time_hours > 0 - and trainer.cumulative_training_time() / (60 * 60) - > cfg.optimization.stop_time_hours - ) - ) + should_stop |= should_stop_early(cfg, valid_losses[0]) # Save checkpoint if do_save or should_stop: - logger.info("begin save checkpoint") checkpoint_utils.save_checkpoint( cfg.checkpoint, trainer, epoch_itr, valid_losses[0] ) From b7b893797eeef76c1f8f179b201d639dde4358ed Mon Sep 17 00:00:00 2001 From: freewym Date: Tue, 3 Nov 2020 21:45:34 -0500 Subject: [PATCH 109/119] Lhotse/K2 support --- espresso/data/__init__.py | 2 + espresso/data/asr_k2_dataset.py | 258 ++++++++++++++++++++ espresso/tasks/speech_recognition_hybrid.py | 30 +++ espresso/tools/Makefile | 8 +- 4 files changed, 297 insertions(+), 1 deletion(-) create mode 100644 espresso/data/asr_k2_dataset.py diff --git a/espresso/data/__init__.py b/espresso/data/__init__.py index 152210569..8edd372e5 100644 --- a/espresso/data/__init__.py +++ b/espresso/data/__init__.py @@ -6,6 +6,7 @@ from .asr_bucket_pad_length_dataset import FeatBucketPadLengthDataset, TextBucketPadLengthDataset from .asr_chain_dataset import AsrChainDataset, NumeratorGraphDataset from .asr_dataset import AsrDataset +from .asr_k2_dataset import AsrK2Dataset from .asr_dictionary import AsrDictionary from .asr_xent_dataset import AliScpCachedDataset, AsrXentDataset from .feat_text_dataset import ( @@ -20,6 +21,7 @@ "AsrChainDataset", "AsrDataset", "AsrDictionary", + "AsrK2Dataset", "AsrTextDataset", "AsrXentDataset", "FeatBucketPadLengthDataset", diff --git a/espresso/data/asr_k2_dataset.py b/espresso/data/asr_k2_dataset.py new file mode 100644 index 000000000..0662016b4 --- /dev/null +++ b/espresso/data/asr_k2_dataset.py @@ -0,0 +1,258 @@ +# 
Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import re +from typing import Dict, List + +import numpy as np + +import torch + +from fairseq.data import FairseqDataset, data_utils + +import espresso.tools.utils as speech_utils +try: + # TODO use pip install once it's available + from espresso.tools.lhotse.cut import CutSet +except ImportError: + raise ImportError("Please install Lhotse by `make lhotse` after entering espresso/tools") + + +def collate(samples, pad_to_length=None, pad_to_multiple=1): + if len(samples) == 0: + return {} + + def merge(key, pad_to_length=None): + if key == "source": + return speech_utils.collate_frames( + [sample[key] for sample in samples], 0.0, + pad_to_length=pad_to_length, + pad_to_multiple=pad_to_multiple, + ) + else: + raise ValueError("Invalid key.") + + id = torch.LongTensor([sample["id"] for sample in samples]) + src_frames = merge( + "source", + pad_to_length=pad_to_length["source"] if pad_to_length is not None else None, + ) + # sort by descending source length + if pad_to_length is not None: + src_lengths = torch.IntTensor( + [sample["source"].ne(0.0).any(dim=1).int().sum() for sample in samples] + ) + else: + src_lengths = torch.IntTensor([s["source"].size(0) for s in samples]) + src_lengths, sort_order = src_lengths.sort(descending=True) + id = id.index_select(0, sort_order) + utt_id = [samples[i]["utt_id"] for i in sort_order.numpy()] + src_frames = src_frames.index_select(0, sort_order) + ntokens = src_lengths.sum().item() + + target = None + if samples[0].get("target", None) is not None and len(samples[0].target) > 0: + # reorder the list of samples to make things easier + # (no need to reorder every element in target) + samples = [samples[i] for i in sort_order.numpy()] + + from torch.utils.data._utils.collate import default_collate + + dataset_idx_to_batch_idx = { + sample["target"][0]["sequence_idx"]: batch_idx + for batch_idx, sample in enumerate(samples) + } + + def update(d: Dict, **kwargs) -> Dict: + for key, value in kwargs.items(): + d[key] = value + return d + + target = default_collate([ + update(sup, sequence_idx=dataset_idx_to_batch_idx[sup["sequence_idx"]]) + for sample in samples + for sup in sample["target"] + ]) + + batch = { + "id": id, + "utt_id": utt_id, + "nsentences": len(samples), + "ntokens": ntokens, + "net_input": { + "src_tokens": src_frames, + "src_lengths": src_lengths, + }, + "target": target, + } + return batch + + +class AsrK2Dataset(FairseqDataset): + """ + A K2 Dataset for ASR. + + Args: + cuts (lhotse.CutSet): Lhotse CutSet to wrap + shuffle (bool, optional): shuffle dataset elements before batching + (default: True). 
+ pad_to_multiple (int, optional): pad src lengths to a multiple of this value + """ + + def __init__( + self, + cuts: CutSet, + shuffle=True, + pad_to_multiple=1, + ): + self.cuts = cuts + self.cut_ids = list(self.cuts.ids) + self.src_sizes = np.array( + [cut.num_frames if cut.has_features else cut.num_samples for cut in cuts] + ) + self.tgt_sizes = None + first_cut = cuts[self.cut_ids[0]] + # assume all cuts have no supervisions if the first one does not + if len(first_cut.supervisions) > 0: + assert len(first_cut.supervisions) == 1, "Only single-supervision cuts are allowed" + assert first_cut.frame_shift is not None, "features are not available in cuts" + self.tgt_sizes = np.array( + [ + round( + cut.supervisions[0].trim(cut.duration).duration / cut.frame_shift + ) for cut in cuts + ] + ) + self.shuffle = shuffle + self.epoch = 1 + self.sizes = ( + np.vstack((self.src_sizes, self.tgt_sizes)).T + if self.tgt_sizes is not None + else self.src_sizes + ) + self.pad_to_multiple = pad_to_multiple + self.feat_dim = self.cuts[self.cut_ids[0]].num_features + + def __getitem__(self, index): + cut_id = self.cut_ids[index] + cut = self.cuts[cut_id] + features = torch.from_numpy(cut.load_features()) + + example = { + "id": index, + "utt_id": cut_id, + "source": features, + "target": [ + { + "sequence_idx": index, + "text": sup.text, + "start_frame": round(sup.start / cut.frame_shift), + "num_frames": round(sup.duration / cut.frame_shift), + } + # CutSet's supervisions can exceed the cut, when the cut starts/ends in the middle + # of a supervision (they would have relative times e.g. -2 seconds start, meaning + # it started 2 seconds before the Cut starts). We use s.trim() to get rid of that + # property, ensuring the supervision time span does not exceed that of the cut. + for sup in (s.trim(cut.duration) for s in cut.supervisions) + ] + } + return example + + def __len__(self): + return len(self.cuts) + + def collater(self, samples, pad_to_length=None): + """Merge a list of samples to form a mini-batch. + + Args: + samples (List[dict]): samples to collate + pad_to_length (dict, optional): a dictionary of + {"source": source_pad_to_length} + to indicate the max length to pad to in source and target respectively. + + Returns: + dict: a mini-batch with the following keys: + + - `id` (LongTensor): example IDs in the original input order + - `utt_id` (List[str]): list of utterance ids + - `nsentences` (int): batch size + - `ntokens` (int): total number of tokens in the batch + - `net_input` (dict): the input to the Model, containing keys: + + - `src_tokens` (FloatTensor): a padded 3D Tensor of features in + the source of shape `(bsz, src_len, feat_dim)`. + - `src_lengths` (IntTensor): 1D Tensor of the unpadded + lengths of each source sequence of shape `(bsz)` + + - `target` (List[Dict[str, Any]]): an List representing a batch of + supervisions + """ + return collate( + samples, pad_to_length=pad_to_length, pad_to_multiple=self.pad_to_multiple, + ) + + def num_tokens(self, index): + """Return the number of frames in a sample. This value is used to + enforce ``--max-tokens`` during batching.""" + return self.src_sizes[index] + + def size(self, index): + """Return an example's size as a float or tuple. This value is used when + filtering a dataset with ``--max-positions``.""" + return ( + self.src_sizes[index], + self.tgt_sizes[index] if self.tgt_sizes is not None else 0, + ) + + def ordered_indices(self): + """Return an ordered list of indices. 
Batches will be constructed based + on this order.""" + if self.shuffle: + indices = np.random.permutation(len(self)).astype(np.int64) + else: + indices = np.arange(len(self), dtype=np.int64) + # sort by target length, then source length + if self.tgt_sizes is not None: + indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")] + return indices[np.argsort(self.src_sizes[indices], kind="mergesort")] + + @property + def supports_prefetch(self): + return False + + def filter_indices_by_size(self, indices, max_sizes): + """Filter a list of sample indices. Remove those that are longer + than specified in max_sizes. + + Args: + indices (np.array): original array of sample indices + max_sizes (int or list[int] or tuple[int]): max sample size, + can be defined separately for src and tgt (then list or tuple) + + Returns: + np.array: filtered sample array + list: list of removed indices + """ + return data_utils.filter_paired_dataset_indices_by_size( + self.src_sizes, + self.tgt_sizes, + indices, + max_sizes, + ) + + @property + def supports_fetch_outside_dataloader(self): + """Whether this dataset supports fetching outside the workers of the dataloader.""" + return False + + @property + def can_reuse_epoch_itr_across_epochs(self): + return False # to avoid running out of CPU RAM + + def set_epoch(self, epoch): + super().set_epoch(epoch) + self.epoch = epoch diff --git a/espresso/tasks/speech_recognition_hybrid.py b/espresso/tasks/speech_recognition_hybrid.py index 26e470ef0..c8600e263 100644 --- a/espresso/tasks/speech_recognition_hybrid.py +++ b/espresso/tasks/speech_recognition_hybrid.py @@ -23,6 +23,7 @@ from espresso.data import ( AliScpCachedDataset, AsrChainDataset, + AsrK2Dataset, AsrXentDataset, AsrDictionary, AsrTextDataset, @@ -74,6 +75,7 @@ class SpeechRecognitionHybridConfig(FairseqDataclass): }, ) feat_in_channels: int = field(default=1, metadata={"help": "feature input channels"}) + use_k2_dataset: bool = field(default=False, metadata={"help": "if True use K2 dataset"}) specaugment_config: Optional[str] = field( default=None, metadata={ @@ -146,6 +148,22 @@ class SpeechRecognitionHybridConfig(FairseqDataclass): max_epoch: int = II("optimization.max_epoch") # to determine whether in trainig stage +def get_k2_dataset_from_json(data_path, split, shuffle=True, pad_to_multiple=1, seed=1): + try: + # TODO use pip install once it's available + from espresso.tools.lhotse.cut import CutSet + except ImportError: + raise ImportError("Please install Lhotse by `make lhotse` after entering espresso/tools") + + data_json_path = os.path.join(data_path, "cuts_{}.json".format(split)) + if not os.path.isfile(data_json_path): + raise FileNotFoundError("Dataset not found: {}".format(data_json_path)) + + cut_set = CutSet.from_json(data_json_path) + logger.info("{} {} examples".format(data_json_path, len(cut_set))) + return AsrK2Dataset(cut_set, shuffle=shuffle, pad_to_multiple=pad_to_multiple) + + def get_asr_dataset_from_json( data_path, split, @@ -344,6 +362,7 @@ def __init__(self, cfg: SpeechRecognitionHybridConfig, dictionary, feat_dim): self.dictionary = dictionary self.feat_dim = feat_dim self.feat_in_channels = cfg.feat_in_channels + self.use_k2_dataset = cfg.use_k2_dataset self.specaugment_config = cfg.specaugment_config self.num_targets = cfg.num_targets self.training_stage = (cfg.max_epoch > 0) # a hack @@ -433,6 +452,17 @@ def load_dataset( data_path = paths[(epoch - 1) % len(paths)] task_cfg = task_cfg or self.cfg + if self.use_k2_dataset: + self.datasets[split] = 
get_k2_dataset_from_json( + data_path, + split, + shuffle=(split != self.cfg.gen_subset), + pad_to_multiple=self.cfg.required_seq_len_multiple, + seed=self.cfg.seed, + ) + self.feat_dim = self.datasets[split].feat_dim + return + self.datasets[split] = get_asr_dataset_from_json( data_path, split, diff --git a/espresso/tools/Makefile b/espresso/tools/Makefile index 81ca5fb0f..8ed219a20 100644 --- a/espresso/tools/Makefile +++ b/espresso/tools/Makefile @@ -1,5 +1,5 @@ KALDI = -PYTHON_DIR = ~/anaconda3/bin +PYTHON_DIR = /export/b03/ywang/anaconda3/bin CXX ?= g++ @@ -30,6 +30,7 @@ kaldi: endif clean: openfst_cleaned + rm -rf lhotse rm -rf pychain rm -rf kaldi @@ -79,3 +80,8 @@ pychain: export PATH=$(PYTHON_DIR):$$PATH && \ cd pychain/openfst_binding && python3 setup.py install && \ cd ../pytorch_binding && python3 setup.py install + +.PHONY: lhotse +lhotse: + test -d lhotse || git clone https://github.com/lhotse-speech/lhotse.git + export PATH=$(PYTHON_DIR):$$PATH && cd lhotse && pip install -e . From 8926fa32acab1ee74cc9b690a635684395976926 Mon Sep 17 00:00:00 2001 From: freewym Date: Thu, 5 Nov 2020 18:43:13 -0500 Subject: [PATCH 110/119] add a data prep example for lhotse --- espresso/data/asr_k2_dataset.py | 2 +- espresso/tools/.gitignore | 1 + examples/mobvoihotwords/cmd.sh | 20 +++ examples/mobvoihotwords/conf/gpu.conf | 10 ++ examples/mobvoihotwords/local/data_prep.py | 169 +++++++++++++++++++++ examples/mobvoihotwords/path.sh | 16 ++ 6 files changed, 217 insertions(+), 1 deletion(-) create mode 100644 examples/mobvoihotwords/cmd.sh create mode 100644 examples/mobvoihotwords/conf/gpu.conf create mode 100644 examples/mobvoihotwords/local/data_prep.py create mode 100644 examples/mobvoihotwords/path.sh diff --git a/espresso/data/asr_k2_dataset.py b/espresso/data/asr_k2_dataset.py index 0662016b4..2c22b44df 100644 --- a/espresso/data/asr_k2_dataset.py +++ b/espresso/data/asr_k2_dataset.py @@ -115,7 +115,7 @@ def __init__( [cut.num_frames if cut.has_features else cut.num_samples for cut in cuts] ) self.tgt_sizes = None - first_cut = cuts[self.cut_ids[0]] + first_cut = next(iter(cuts)) # assume all cuts have no supervisions if the first one does not if len(first_cut.supervisions) > 0: assert len(first_cut.supervisions) == 1, "Only single-supervision cuts are allowed" diff --git a/espresso/tools/.gitignore b/espresso/tools/.gitignore index 67d77be3d..8f95b0195 100644 --- a/espresso/tools/.gitignore +++ b/espresso/tools/.gitignore @@ -1,3 +1,4 @@ kaldi openfst* pychain +lhotse diff --git a/examples/mobvoihotwords/cmd.sh b/examples/mobvoihotwords/cmd.sh new file mode 100644 index 000000000..e531b4431 --- /dev/null +++ b/examples/mobvoihotwords/cmd.sh @@ -0,0 +1,20 @@ +# you can change cmd.sh depending on what type of queue you are using. +# If you have no queueing system and want to run on a local machine, you +# can change all instances 'queue.pl' to run.pl (but be careful and run +# commands one by one: most recipes will exhaust the memory on your +# machine). queue.pl works with GridEngine (qsub). slurm.pl works +# with slurm. Different queues are configured differently, with different +# queue names and different ways of specifying things like memory; +# to account for these differences you can create and edit the file +# conf/queue.conf to match your queue's configuration. Search for +# conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information, +# or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl. 
+ +#export train_cmd="run.pl --mem 4G" +#export cuda_cmd="run.pl --mem 4G --gpu 1" +#export decode_cmd="run.pl --mem 4G" + +# JHU setup (copy queue-freegpu.pl from ESPnet into utils/) +export train_cmd="queue.pl --mem 4G" +export cuda_cmd="queue-freegpu.pl --mem 8G --gpu 1 --config conf/gpu.conf" +export decode_cmd="queue.pl --mem 4G" diff --git a/examples/mobvoihotwords/conf/gpu.conf b/examples/mobvoihotwords/conf/gpu.conf new file mode 100644 index 000000000..5cc94adf2 --- /dev/null +++ b/examples/mobvoihotwords/conf/gpu.conf @@ -0,0 +1,10 @@ +# Default configuration +command qsub -v PATH -cwd -S /bin/bash -j y -l arch=*64* +option mem=* -l mem_free=$0,ram_free=$0 +option mem=0 # Do not add anything to qsub_opts +option num_threads=* -pe smp $0 +option num_threads=1 # Do not add anything to qsub_opts +option max_jobs_run=* -tc $0 +default gpu=0 +option gpu=0 +option gpu=* -l 'hostname=c*,gpu=$0' -q g.q diff --git a/examples/mobvoihotwords/local/data_prep.py b/examples/mobvoihotwords/local/data_prep.py new file mode 100644 index 000000000..83007ebef --- /dev/null +++ b/examples/mobvoihotwords/local/data_prep.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python3 +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import logging +import os +import sys +from concurrent.futures import ProcessPoolExecutor +from pathlib import Path + +import numpy as np + + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger(__name__) + + +def get_parser(): + parser = argparse.ArgumentParser( + description="data preparation for the MobvoiHotwords corpus" + ) + # fmt: off + parser.add_argument("--data-dir", default="data", type=str, help="data directory") + parser.add_argument("--seed", default=1, type=int, help="random seed") + parser.add_argument( + "--nj", default=1, type=int, help="number of jobs for features extraction" + ) + # fmt: on + + return parser + + +def main(args): + try: + # TODO use pip install once it's available + from espresso.tools.lhotse import CutSet, Mfcc, MfccConfig, LilcomFilesWriter, WavAugmenter + from espresso.tools.lhotse.manipulation import combine + from espresso.tools.lhotse.recipes.mobvoihotwords import download_and_untar, prepare_mobvoihotwords + except ImportError: + raise ImportError("Please install Lhotse by `make lhotse` after entering espresso/tools") + + root_dir = Path(args.data_dir) + corpus_dir = root_dir / "MobvoiHotwords" + output_dir = root_dir + + # Download and extract the corpus + download_and_untar(root_dir) + + # Prepare manifests + mobvoihotwords_manifests = prepare_mobvoihotwords(corpus_dir, output_dir) + logger.info( + "train/dev/test size: {}/{}/{}".format( + len(mobvoihotwords_manifests["train"]["recordings"]), + len(mobvoihotwords_manifests["dev"]["recordings"]), + len(mobvoihotwords_manifests["test"]["recordings"]) + ) + ) + + # Data augmentation + np.random.seed(args.seed) + # equivalent to Kaldi's mfcc_hires config + mfcc = Mfcc(config=MfccConfig(num_mel_bins=40, num_ceps=40, low_freq=20, high_freq=-400)) + num_jobs = args.nj + for partition, manifests in mobvoihotwords_manifests.items(): + cut_set = CutSet.from_manifests( + recordings=manifests["recordings"], + supervisions=manifests["supervisions"], + ) + sampling_rate = next(iter(cut_set)).sampling_rate + with 
ProcessPoolExecutor(num_jobs) as ex: + if "train" in partition: + # original set + with LilcomFilesWriter(f"{output_dir}/feats_{partition}_orig") as storage: + cut_set_orig = cut_set.compute_and_store_features( + extractor=mfcc, + storage=storage, + augmenter=None, + executor=ex, + ) + # augmented with reverbration + with LilcomFilesWriter(f"{output_dir}/feats_{partition}_rev") as storage: + cut_set_rev = cut_set.compute_and_store_features( + extractor=mfcc, + storage=storage, + augmenter=WavAugmenter(effect_chain=reverb()), + excutor=ex, + ) + cut_set_rev = CutSet.from_cuts( + cut.with_id("rev-" + cut.id) for cut in cut_set_rev.cuts + ) + # augmented with speed perturbation + with LilcomFilesWriter(f"{output_dir}/feats_{partition}_sp1.1") as storage: + cut_set_sp1p1 = cut_set.compute_and_store_features( + extractor=mfcc, + storage=storage, + augmenter=WavAugmenter( + effect_chain=speed(sampling_rate=sampling_rate, factor=1.1) + ), + excutor=ex, + ) + cut_set_sp1p1 = CutSet.from_cuts( + cut.with_id("sp1.1-" + cut.id) for cut in cut_set_sp1p1.cuts + ) + with LilcomFilesWriter(f"{output_dir}/feats_{partition}_sp0.9") as storage: + cut_set_sp0p9 = cut_set.compute_and_store_features( + extractor=mfcc, + storage=storage, + augmenter=WavAugmenter( + effect_chain=speed(sampling_rate=sampling_rate, factor=0.9) + ), + excutor=ex, + ) + cut_set_sp0p9 = CutSet.from_cuts( + cut.with_id("sp0.9-" + cut.id) for cut in cut_set_sp0p9.cuts + ) + # combine the original and augmented sets together + cut_set = combine( + cut_set_orig, cut_set_rev, cut_set_sp1p1, cut_set_sp0p9 + ) + else: # no augmentations for dev and test sets + with LilcomFilesWriter(f"{output_dir}/feats_{partition}") as storage: + cut_set = cut_set.compute_and_store_features( + extractor=mfcc, + storage=storage, + augmenter=None, + executor=ex, + ) + mobvoihotwords_manifests[partition]["cuts"] = cut_set + cut_set.to_json(output_dir / f"cuts_{partition}.json.gz") + + +def reverb(*args, **kwargs): + """ + Returns a reverb effect for wav augmentation. + """ + import augment + effect_chain = augment.EffectChain() + # Reverb it makes the signal to have two channels, + # which we combine into 1 by running `channels` w/o parameters + effect_chain.reverb(50, 50, lambda: np.random.randint(1, 30)).channels() + return effect_chain + + +def speed(sampling_rate: int, factor: float): + """ + Returns a speed perturbation effect with for wav augmentation. + :param sampling_rate: a sampling rate value for which the effect will be created (resampling is needed for speed). + :param factor: speed perturbation factor + """ + import augment + effect_chain = augment.EffectChain() + # The speed effect changes the sampling ratio; we have to compensate for that. + # Here, we specify 'quick' options on both pitch and rate effects, to speed up things + effect_chain.speed("-q", lambda: factor).rate("-q", sampling_rate) + return effect_chain + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + main(args) diff --git a/examples/mobvoihotwords/path.sh b/examples/mobvoihotwords/path.sh new file mode 100644 index 000000000..a2576bef6 --- /dev/null +++ b/examples/mobvoihotwords/path.sh @@ -0,0 +1,16 @@ +MAIN_ROOT=$PWD/../.. +export KALDI_ROOT=$MAIN_ROOT/espresso/tools/kaldi + +# BEGIN from kaldi path.sh +[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh +export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/tools/sctk/bin:$PWD:$PATH +[ ! 
-f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 +. $KALDI_ROOT/tools/config/common_path.sh +export LC_ALL=C +# END + +export PATH=~/anaconda3/bin:$PATH +export PATH=$MAIN_ROOT:$MAIN_ROOT/espresso:$MAIN_ROOT/espresso/tools:$PATH +export LD_LIBRARY_PATH=$MAIN_ROOT/espresso/tools/openfst/lib:$LD_LIBRARY_PATH +export PYTHONPATH=$MAIN_ROOT:$MAIN_ROOT/espresso:$MAIN_ROOT/espresso/tools:$MAIN_ROOT/espresso/tools/lhotse:$MAIN_ROOT/espresso/tools/pychain:$PYTHONPATH +export PYTHONUNBUFFERED=1 From 083ea69fc91029bf23bb8911d620688ac843ca29 Mon Sep 17 00:00:00 2001 From: freewym Date: Fri, 6 Nov 2020 21:03:00 -0500 Subject: [PATCH 111/119] add random split of negatives --- examples/mobvoihotwords/local/data_prep.py | 149 +++++++++++++++++---- 1 file changed, 123 insertions(+), 26 deletions(-) diff --git a/examples/mobvoihotwords/local/data_prep.py b/examples/mobvoihotwords/local/data_prep.py index 83007ebef..821c228a8 100644 --- a/examples/mobvoihotwords/local/data_prep.py +++ b/examples/mobvoihotwords/local/data_prep.py @@ -8,11 +8,24 @@ import logging import os import sys +from typing import List from concurrent.futures import ProcessPoolExecutor from pathlib import Path import numpy as np +from fairseq.data.data_utils import numpy_seed + +try: + # TODO use pip install once it's available + from espresso.tools.lhotse import ( + CutSet, Mfcc, MfccConfig, LilcomFilesWriter, SupervisionSet, WavAugmenter + ) + from espresso.tools.lhotse.manipulation import combine + from espresso.tools.lhotse.recipes.mobvoihotwords import download_and_untar, prepare_mobvoihotwords +except ImportError: + raise ImportError("Please install Lhotse by `make lhotse` after entering espresso/tools") + logging.basicConfig( format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", @@ -31,7 +44,15 @@ def get_parser(): parser.add_argument("--data-dir", default="data", type=str, help="data directory") parser.add_argument("--seed", default=1, type=int, help="random seed") parser.add_argument( - "--nj", default=1, type=int, help="number of jobs for features extraction" + "--num-jobs", default=1, type=int, help="number of jobs for features extraction" + ) + parser.add_argument( + "--max-remaining-duration", default=0.3, type=float, + help="not split if the left-over duration is less than this many seconds" + ) + parser.add_argument( + "--overlap-duration", default=0.3, type=float, + help="overlap between adjacent segments while splitting negative recordings" ) # fmt: on @@ -39,14 +60,6 @@ def get_parser(): def main(args): - try: - # TODO use pip install once it's available - from espresso.tools.lhotse import CutSet, Mfcc, MfccConfig, LilcomFilesWriter, WavAugmenter - from espresso.tools.lhotse.manipulation import combine - from espresso.tools.lhotse.recipes.mobvoihotwords import download_and_untar, prepare_mobvoihotwords - except ImportError: - raise ImportError("Please install Lhotse by `make lhotse` after entering espresso/tools") - root_dir = Path(args.data_dir) corpus_dir = root_dir / "MobvoiHotwords" output_dir = root_dir @@ -68,36 +81,46 @@ def main(args): np.random.seed(args.seed) # equivalent to Kaldi's mfcc_hires config mfcc = Mfcc(config=MfccConfig(num_mel_bins=40, num_ceps=40, low_freq=20, high_freq=-400)) - num_jobs = args.nj for partition, manifests in mobvoihotwords_manifests.items(): cut_set = CutSet.from_manifests( recordings=manifests["recordings"], supervisions=manifests["supervisions"], ) sampling_rate = 
next(iter(cut_set)).sampling_rate - with ProcessPoolExecutor(num_jobs) as ex: + with ProcessPoolExecutor(args.num_jobs) as ex: if "train" in partition: - # original set - with LilcomFilesWriter(f"{output_dir}/feats_{partition}_orig") as storage: - cut_set_orig = cut_set.compute_and_store_features( + # split negative recordings into smaller chunks with lengths sampled from + # length distribution of positive recordings + pos_durs = get_positive_durations(manifests["supervisions"]) + with numpy_seed(args.seed): + cut_set = keep_positives_and_split_negatives( + cut_set, + pos_durs, + max_remaining_duration=args.max_remaining_duration, + overlap_duration=args.overlap_duration, + ) + # "clean" set + with LilcomFilesWriter(f"{output_dir}/feats_{partition}_clean") as storage: + cut_set_clean = cut_set.compute_and_store_features( extractor=mfcc, storage=storage, augmenter=None, executor=ex, ) - # augmented with reverbration - with LilcomFilesWriter(f"{output_dir}/feats_{partition}_rev") as storage: - cut_set_rev = cut_set.compute_and_store_features( - extractor=mfcc, - storage=storage, - augmenter=WavAugmenter(effect_chain=reverb()), - excutor=ex, - ) + # augmented with reverberation + with LilcomFilesWriter(f"{output_dir}/feats_{partition}_rev") as storage: + with numpy_seed(args.seed): + cut_set_rev = cut_set.compute_and_store_features( + extractor=mfcc, + storage=storage, + augmenter=WavAugmenter(effect_chain=reverb()), + excutor=ex, + ) cut_set_rev = CutSet.from_cuts( cut.with_id("rev-" + cut.id) for cut in cut_set_rev.cuts ) # augmented with speed perturbation - with LilcomFilesWriter(f"{output_dir}/feats_{partition}_sp1.1") as storage: + with LilcomFilesWriter(f"{output_dir}/feats_{partition}_sp1.1") as storage: cut_set_sp1p1 = cut_set.compute_and_store_features( extractor=mfcc, storage=storage, @@ -109,7 +132,7 @@ def main(args): cut_set_sp1p1 = CutSet.from_cuts( cut.with_id("sp1.1-" + cut.id) for cut in cut_set_sp1p1.cuts ) - with LilcomFilesWriter(f"{output_dir}/feats_{partition}_sp0.9") as storage: + with LilcomFilesWriter(f"{output_dir}/feats_{partition}_sp0.9") as storage: cut_set_sp0p9 = cut_set.compute_and_store_features( extractor=mfcc, storage=storage, @@ -121,9 +144,9 @@ def main(args): cut_set_sp0p9 = CutSet.from_cuts( cut.with_id("sp0.9-" + cut.id) for cut in cut_set_sp0p9.cuts ) - # combine the original and augmented sets together + # combine the clean and augmented sets together cut_set = combine( - cut_set_orig, cut_set_rev, cut_set_sp1p1, cut_set_sp0p9 + cut_set_clean, cut_set_rev, cut_set_sp1p1, cut_set_sp0p9 ) else: # no augmentations for dev and test sets with LilcomFilesWriter(f"{output_dir}/feats_{partition}") as storage: @@ -137,6 +160,80 @@ def main(args): cut_set.to_json(output_dir / f"cuts_{partition}.json.gz") +def get_positive_durations(sup_set: SupervisionSet) -> List[float]: + """ + Get duration values of all positive recordings, assuming Supervison.text is + "FREETEXT" for all negative recordings, and SupervisionSegment.duration + equals to the corresponding Recording.duration. 
+ """ + return [sup.dur for sup in sup_set.filter(lambda seg: seg.text != "FREETEXT")] + + +def keep_positives_and_split_negatives( + cut_set: CutSet, + durations: List[float], + max_remaining_duration: float = 0.3, + overlap_duration: float = 0.3, +) -> CutSet: + """ + Returns a new CutSet where all the positives are directly taken from the original + input cut set, and the negatives are obtained by splitting original negatives + into shorter chunks of random lengths drawn from the given length distribution + (here it is the empirical distribution of the positive recordings), There can + be overlap between chunks. + + Args: + cut_set (CutSet): original input cut set + durations (list[float]): list of durations to sample from + max_remaining_duration (float, optional): not split if the left-over + duration is less than this many seconds (default: 0.3). + overlap_duration (float, optional): overlap between adjacent segments + (default: None) + + Returns: + CutSet: a new cut set after split + """ + assert max_remaining_duration >= 0.0 and overlap_duration >= 0.0 + new_cuts = [] + for cut in cut_set: + assert len(cut.supervisions) == 1 + if cut.supervisions[0].text != "FREETEXT": # keep the positive as it is + new_cuts.append(cut) + else: + this_offset = cut.start + this_offset_relative = this_offset - cut.start + remaining_duration = cut.duration + this_dur = durations[np.random.randint(len(durations))] + while remaining_duration > this_dur + max_remaining_duration: + new_cut = cut.truncate( + offset=this_offset_relative, duration=this_dur, preserve_id=True + ) + new_cut = new_cut.with_id( + "{id}-{s:07d}-{e:07d}".format( + id=new_cut.id, + s=int(round(100 * this_offset_relative)), + e=int(round(100 * (this_offset_relative + this_dur))) + ) + ) + new_cuts.append(new_cut) + this_offset += this_dur - overlap_duration + this_offset_relative = this_offset - cut.start + remaining_duration -= this_dur - overlap_duration + this_dur = durations[np.random.randint(len(durations))] + + new_cut = cut.truncate(offset=this_offset_relative, preserve_id=True) + new_cut = new_cut.with_id( + "{id}-{s:07d}-{e:07d}".format( + id=new_cut.id, + s=int(round(100 * this_offset_relative)), + e=int(round(100 * cut.duration)) + ) + ) + new_cuts.append(new_cut) + + return CutSet.from_cuts(new_cuts) + + def reverb(*args, **kwargs): """ Returns a reverb effect for wav augmentation. From ecf84236b06960b6e5b087be57fc49ab5521ae46 Mon Sep 17 00:00:00 2001 From: freewym Date: Sun, 8 Nov 2020 03:03:32 -0500 Subject: [PATCH 112/119] misc fixes --- espresso/__init__.py | 1 + espresso/data/__init__.py | 4 +- .../{asr_k2_dataset.py => k2_asr_dataset.py} | 31 ++- espresso/tasks/speech_recognition_hybrid.py | 13 +- examples/mobvoihotwords/cmd.sh | 2 +- examples/mobvoihotwords/local/data_prep.py | 202 +++++++++++------- examples/mobvoihotwords/path.sh | 3 +- examples/mobvoihotwords/run.sh | 29 +++ examples/mobvoihotwords/utils | 1 + 9 files changed, 192 insertions(+), 94 deletions(-) rename espresso/data/{asr_k2_dataset.py => k2_asr_dataset.py} (88%) mode change 100644 => 100755 examples/mobvoihotwords/local/data_prep.py create mode 100755 examples/mobvoihotwords/run.sh create mode 120000 examples/mobvoihotwords/utils diff --git a/espresso/__init__.py b/espresso/__init__.py index 666272541..9db006656 100644 --- a/espresso/__init__.py +++ b/espresso/__init__.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import espresso.tools # noqa import espresso.criterions # noqa import espresso.models # noqa import espresso.modules # noqa diff --git a/espresso/data/__init__.py b/espresso/data/__init__.py index 8edd372e5..b414edae9 100644 --- a/espresso/data/__init__.py +++ b/espresso/data/__init__.py @@ -6,7 +6,7 @@ from .asr_bucket_pad_length_dataset import FeatBucketPadLengthDataset, TextBucketPadLengthDataset from .asr_chain_dataset import AsrChainDataset, NumeratorGraphDataset from .asr_dataset import AsrDataset -from .asr_k2_dataset import AsrK2Dataset +from .k2_asr_dataset import K2AsrDataset from .asr_dictionary import AsrDictionary from .asr_xent_dataset import AliScpCachedDataset, AsrXentDataset from .feat_text_dataset import ( @@ -21,13 +21,13 @@ "AsrChainDataset", "AsrDataset", "AsrDictionary", - "AsrK2Dataset", "AsrTextDataset", "AsrXentDataset", "FeatBucketPadLengthDataset", "FeatScpCachedDataset", "FeatScpDataset", "FeatScpInMemoryDataset", + "K2AsrDataset", "NumeratorGraphDataset", "TextBucketPadLengthDataset", ] diff --git a/espresso/data/asr_k2_dataset.py b/espresso/data/k2_asr_dataset.py similarity index 88% rename from espresso/data/asr_k2_dataset.py rename to espresso/data/k2_asr_dataset.py index 2c22b44df..d67d707ff 100644 --- a/espresso/data/asr_k2_dataset.py +++ b/espresso/data/k2_asr_dataset.py @@ -6,7 +6,7 @@ import logging import os import re -from typing import Dict, List +from typing import Any, Dict, List, Optional import numpy as np @@ -17,12 +17,23 @@ import espresso.tools.utils as speech_utils try: # TODO use pip install once it's available - from espresso.tools.lhotse.cut import CutSet + from espresso.tools.lhotse.lhotse import CutSet except ImportError: raise ImportError("Please install Lhotse by `make lhotse` after entering espresso/tools") -def collate(samples, pad_to_length=None, pad_to_multiple=1): +def collate( + samples: List[Dict[str, Any]], + pad_to_length: Optional[Dict[str, int]] = None, + pad_to_multiple: int = 1, +) -> Dict[str, Any]: + """Collate samples into a batch. We use :func:`speech_utils.collate_frames` + to collate and pad input frames, and PyTorch's :func:`default_collate` + to collate and pad target/supervisions (following the example provided in Lhotse). + Samples in the batch are in descending order of their input frame lengths. + It also allows to specify the padded input length and further enforce the length + to be a multiple of `pad_to_multiple` + """ if len(samples) == 0: return {} @@ -92,12 +103,12 @@ def update(d: Dict, **kwargs) -> Dict: return batch -class AsrK2Dataset(FairseqDataset): +class K2AsrDataset(FairseqDataset): """ A K2 Dataset for ASR. Args: - cuts (lhotse.CutSet): Lhotse CutSet to wrap + cuts (lhotse.CutSet): instance of Lhotse's CutSet to wrap shuffle (bool, optional): shuffle dataset elements before batching (default: True). pad_to_multiple (int, optional): pad src lengths to a multiple of this value @@ -165,14 +176,18 @@ def __getitem__(self, index): def __len__(self): return len(self.cuts) - def collater(self, samples, pad_to_length=None): + def collater( + self, + samples: List[Dict[str, Any]], + pad_to_length: Optional[Dict[str, int]] = None, + ) -> Dict[str, Any]: """Merge a list of samples to form a mini-batch. Args: samples (List[dict]): samples to collate pad_to_length (dict, optional): a dictionary of {"source": source_pad_to_length} - to indicate the max length to pad to in source and target respectively. + to indicate the max length to pad to in source. 
Returns: dict: a mini-batch with the following keys: @@ -188,7 +203,7 @@ def collater(self, samples, pad_to_length=None): - `src_lengths` (IntTensor): 1D Tensor of the unpadded lengths of each source sequence of shape `(bsz)` - - `target` (List[Dict[str, Any]]): an List representing a batch of + - `target` (List[Dict[str, Any]]): a List representing a batch of supervisions """ return collate( diff --git a/espresso/tasks/speech_recognition_hybrid.py b/espresso/tasks/speech_recognition_hybrid.py index c8600e263..711b561eb 100644 --- a/espresso/tasks/speech_recognition_hybrid.py +++ b/espresso/tasks/speech_recognition_hybrid.py @@ -23,11 +23,11 @@ from espresso.data import ( AliScpCachedDataset, AsrChainDataset, - AsrK2Dataset, AsrXentDataset, AsrDictionary, AsrTextDataset, FeatScpCachedDataset, + K2AsrDataset, NumeratorGraphDataset, ) @@ -151,7 +151,7 @@ class SpeechRecognitionHybridConfig(FairseqDataclass): def get_k2_dataset_from_json(data_path, split, shuffle=True, pad_to_multiple=1, seed=1): try: # TODO use pip install once it's available - from espresso.tools.lhotse.cut import CutSet + from espresso.tools.lhotse.lhotse import CutSet except ImportError: raise ImportError("Please install Lhotse by `make lhotse` after entering espresso/tools") @@ -161,7 +161,7 @@ def get_k2_dataset_from_json(data_path, split, shuffle=True, pad_to_multiple=1, cut_set = CutSet.from_json(data_json_path) logger.info("{} {} examples".format(data_json_path, len(cut_set))) - return AsrK2Dataset(cut_set, shuffle=shuffle, pad_to_multiple=pad_to_multiple) + return K2AsrDataset(cut_set, shuffle=shuffle, pad_to_multiple=pad_to_multiple) def get_asr_dataset_from_json( @@ -413,6 +413,12 @@ def setup_task(cls, cfg: SpeechRecognitionHybridConfig, **kwargs): assert len(paths) > 0 data_path = paths[0] split = cfg.valid_subset.split(",")[0] # valid set is usually much smaller than train set, so it's faster + if cfg.use_k2_dataset: + try: + feat_dim = get_k2_dataset_from_json(data_path, split).feat_dim + except FileNotFoundError: + feat_dim = get_k2_dataset_from_json(data_path, cfg.gen_subset).feat_dim + return cls(cfg, dictionary, feat_dim) try: src_dataset = get_asr_dataset_from_json(data_path, split, dictionary, combine=False).src except FileNotFoundError: @@ -460,7 +466,6 @@ def load_dataset( pad_to_multiple=self.cfg.required_seq_len_multiple, seed=self.cfg.seed, ) - self.feat_dim = self.datasets[split].feat_dim return self.datasets[split] = get_asr_dataset_from_json( diff --git a/examples/mobvoihotwords/cmd.sh b/examples/mobvoihotwords/cmd.sh index e531b4431..382076813 100644 --- a/examples/mobvoihotwords/cmd.sh +++ b/examples/mobvoihotwords/cmd.sh @@ -15,6 +15,6 @@ #export decode_cmd="run.pl --mem 4G" # JHU setup (copy queue-freegpu.pl from ESPnet into utils/) -export train_cmd="queue.pl --mem 4G" +export train_cmd="queue.pl --mem 32G" export cuda_cmd="queue-freegpu.pl --mem 8G --gpu 1 --config conf/gpu.conf" export decode_cmd="queue.pl --mem 4G" diff --git a/examples/mobvoihotwords/local/data_prep.py b/examples/mobvoihotwords/local/data_prep.py old mode 100644 new mode 100755 index 821c228a8..df64dbd04 --- a/examples/mobvoihotwords/local/data_prep.py +++ b/examples/mobvoihotwords/local/data_prep.py @@ -9,20 +9,24 @@ import os import sys from typing import List +from collections import defaultdict from concurrent.futures import ProcessPoolExecutor +import multiprocessing from pathlib import Path import numpy as np from fairseq.data.data_utils import numpy_seed + try: # TODO use pip install once it's available - from 
espresso.tools.lhotse import ( - CutSet, Mfcc, MfccConfig, LilcomFilesWriter, SupervisionSet, WavAugmenter + from espresso.tools.lhotse.lhotse import ( + CutSet, Mfcc, MfccConfig, LilcomFilesWriter, RecordingSet, SupervisionSet ) - from espresso.tools.lhotse.manipulation import combine - from espresso.tools.lhotse.recipes.mobvoihotwords import download_and_untar, prepare_mobvoihotwords + from espresso.tools.lhotse.lhotse.augmentation import SoxEffectTransform, RandomValue + from espresso.tools.lhotse.lhotse.manipulation import combine + from espresso.tools.lhotse.lhotse.recipes.mobvoihotwords import download_and_untar, prepare_mobvoihotwords except ImportError: raise ImportError("Please install Lhotse by `make lhotse` after entering espresso/tools") @@ -33,7 +37,7 @@ level=os.environ.get("LOGLEVEL", "INFO").upper(), stream=sys.stdout, ) -logger = logging.getLogger(__name__) +logger = logging.getLogger("mobvoihotwords.data_prep") def get_parser(): @@ -44,7 +48,7 @@ def get_parser(): parser.add_argument("--data-dir", default="data", type=str, help="data directory") parser.add_argument("--seed", default=1, type=int, help="random seed") parser.add_argument( - "--num-jobs", default=1, type=int, help="number of jobs for features extraction" + "--num-workers", default=1, type=int, help="number of workers for features extraction" ) parser.add_argument( "--max-remaining-duration", default=0.3, type=float, @@ -64,11 +68,25 @@ def main(args): corpus_dir = root_dir / "MobvoiHotwords" output_dir = root_dir - # Download and extract the corpus + logger.info(f"Download and extract the corpus") download_and_untar(root_dir) - # Prepare manifests - mobvoihotwords_manifests = prepare_mobvoihotwords(corpus_dir, output_dir) + logger.info(f"Prepare the manifests") + partitions = ["train", "dev", "test"] + if all( + (output_dir / f"{key}_{part}.json").is_file() + for key in ["recordings", "supervisions"] for part in partitions + ): + logger.info(f"All the manifests files are found in {output_dir}. 
Load from them directly") + mobvoihotwords_manifests = defaultdict(dict) + for part in partitions: + mobvoihotwords_manifests[part] = { + "recordings": RecordingSet.from_json(output_dir / f"recordings_{part}.json"), + "supervisions": SupervisionSet.from_json(output_dir / f"supervisions_{part}.json") + } + else: + logger.info("It may take long time") + mobvoihotwords_manifests = prepare_mobvoihotwords(corpus_dir, output_dir) logger.info( "train/dev/test size: {}/{}/{}".format( len(mobvoihotwords_manifests["train"]["recordings"]), @@ -81,16 +99,17 @@ def main(args): np.random.seed(args.seed) # equivalent to Kaldi's mfcc_hires config mfcc = Mfcc(config=MfccConfig(num_mel_bins=40, num_ceps=40, low_freq=20, high_freq=-400)) - for partition, manifests in mobvoihotwords_manifests.items(): + for part, manifests in mobvoihotwords_manifests.items(): cut_set = CutSet.from_manifests( recordings=manifests["recordings"], supervisions=manifests["supervisions"], ) sampling_rate = next(iter(cut_set)).sampling_rate - with ProcessPoolExecutor(args.num_jobs) as ex: - if "train" in partition: + with ProcessPoolExecutor(args.num_workers, mp_context=multiprocessing.get_context("spawn")) as ex: + if part == "train": # split negative recordings into smaller chunks with lengths sampled from # length distribution of positive recordings + logger.info(f"Split negative recordings in '{part}' set") pos_durs = get_positive_durations(manifests["supervisions"]) with numpy_seed(args.seed): cut_set = keep_positives_and_split_negatives( @@ -99,65 +118,103 @@ def main(args): max_remaining_duration=args.max_remaining_duration, overlap_duration=args.overlap_duration, ) + # "clean" set - with LilcomFilesWriter(f"{output_dir}/feats_{partition}_clean") as storage: - cut_set_clean = cut_set.compute_and_store_features( - extractor=mfcc, - storage=storage, - augmenter=None, - executor=ex, - ) - # augmented with reverberation - with LilcomFilesWriter(f"{output_dir}/feats_{partition}_rev") as storage: - with numpy_seed(args.seed): - cut_set_rev = cut_set.compute_and_store_features( + logger.info(f"Extract features for '{part}' set") + json_path = output_dir / f"cuts_{part}_clean.json.gz" + if json_path.is_file(): + logger.info(f"{json_path} exists, skip the extraction (remove it if you want to re-generate it)") + cut_set_clean = CutSet.from_json(json_path) + else: + with LilcomFilesWriter(f"{output_dir}/feats_{part}_clean") as storage: + cut_set_clean = cut_set.compute_and_store_features( extractor=mfcc, storage=storage, - augmenter=WavAugmenter(effect_chain=reverb()), - excutor=ex, + augment_fn=None, + executor=ex, ) + cut_set_clean.to_json(json_path) + + # augmented with reverberation + logger.info(f"Extract features for '{part}' set with reverberation") + json_path = output_dir / f"cuts_{part}_rev.json.gz" + if json_path.is_file(): + logger.info(f"{json_path} exists, skip the extraction (remove it if you want to re-generate it)") + cut_set_rev = CutSet.from_json(json_path) + else: + augment_fn = SoxEffectTransform(effects=reverb(sampling_rate=sampling_rate)) + with LilcomFilesWriter(f"{output_dir}/feats_{part}_rev") as storage: + with numpy_seed(args.seed): + cut_set_rev = cut_set.compute_and_store_features( + extractor=mfcc, + storage=storage, + augment_fn=augment_fn, + executor=ex, + ) cut_set_rev = CutSet.from_cuts( - cut.with_id("rev-" + cut.id) for cut in cut_set_rev.cuts + cut.with_id("rev-" + cut.id) for cut in cut_set_rev ) + cut_set_rev.to_json(json_path) + # augmented with speed perturbation - with 
LilcomFilesWriter(f"{output_dir}/feats_{partition}_sp1.1") as storage: - cut_set_sp1p1 = cut_set.compute_and_store_features( - extractor=mfcc, - storage=storage, - augmenter=WavAugmenter( - effect_chain=speed(sampling_rate=sampling_rate, factor=1.1) - ), - excutor=ex, - ) + logger.info(f"Extract features for '{part}' set with speed perturbation") + json_path = output_dir / f"cuts_{part}_sp1.1.json.gz" + if json_path.is_file(): + logger.info(f"{json_path} exists, skip the extraction (remove it if you want to re-generate it)") + cut_set_sp1p1 = CutSet.from_json(json_path) + else: + augment_fn = SoxEffectTransform(effects=speed(sampling_rate=sampling_rate, factor=1.1)) + with LilcomFilesWriter(f"{output_dir}/feats_{part}_sp1.1") as storage: + cut_set_sp1p1 = cut_set.compute_and_store_features( + extractor=mfcc, + storage=storage, + augment_fn=augment_fn, + executor=ex, + ) cut_set_sp1p1 = CutSet.from_cuts( - cut.with_id("sp1.1-" + cut.id) for cut in cut_set_sp1p1.cuts - ) - with LilcomFilesWriter(f"{output_dir}/feats_{partition}_sp0.9") as storage: - cut_set_sp0p9 = cut_set.compute_and_store_features( - extractor=mfcc, - storage=storage, - augmenter=WavAugmenter( - effect_chain=speed(sampling_rate=sampling_rate, factor=0.9) - ), - excutor=ex, + cut.with_id("sp1.1-" + cut.id) for cut in cut_set_sp1p1 ) + cut_set_sp1p1.to_json(json_path) + json_path = output_dir / f"cuts_{part}_sp0.9.json.gz" + if json_path.is_file(): + logger.info(f"{json_path} exists, skip the extraction") + cut_set_sp1p1 = CutSet.from_json(json_path) + else: + augment_fn = SoxEffectTransform(effects=speed(sampling_rate=sampling_rate, factor=0.9)) + with LilcomFilesWriter(f"{output_dir}/feats_{part}_sp0.9") as storage: + cut_set_sp0p9 = cut_set.compute_and_store_features( + extractor=mfcc, + storage=storage, + augment_fn=augment_fn, + executor=ex, + ) cut_set_sp0p9 = CutSet.from_cuts( - cut.with_id("sp0.9-" + cut.id) for cut in cut_set_sp0p9.cuts + cut.with_id("sp0.9-" + cut.id) for cut in cut_set_sp0p9 ) + cut_set_sp0p9.to_json(json_path) + # combine the clean and augmented sets together + logger.info(f"Combine all the features above") cut_set = combine( cut_set_clean, cut_set_rev, cut_set_sp1p1, cut_set_sp0p9 ) else: # no augmentations for dev and test sets - with LilcomFilesWriter(f"{output_dir}/feats_{partition}") as storage: - cut_set = cut_set.compute_and_store_features( - extractor=mfcc, - storage=storage, - augmenter=None, - executor=ex, - ) - mobvoihotwords_manifests[partition]["cuts"] = cut_set - cut_set.to_json(output_dir / f"cuts_{partition}.json.gz") + logger.info(f"extract features for '{part}' set") + json_path = output_dir / f"cuts_{part}.json.gz" + if json_path.is_file(): + logger.info(f"{json_path} exists, skip the extraction (remove it if you want to re-generate it)") + cut_set = CutSet.from_json(json_path) + else: + with LilcomFilesWriter(f"{output_dir}/feats_{part}") as storage: + cut_set = cut_set.compute_and_store_features( + extractor=mfcc, + storage=storage, + augmenter=None, + executor=ex, + ) + + mobvoihotwords_manifests[part]["cuts"] = cut_set + cut_set.to_json(output_dir / f"cuts_{part}.json.gz") def get_positive_durations(sup_set: SupervisionSet) -> List[float]: @@ -166,7 +223,7 @@ def get_positive_durations(sup_set: SupervisionSet) -> List[float]: "FREETEXT" for all negative recordings, and SupervisionSegment.duration equals to the corresponding Recording.duration. 
""" - return [sup.dur for sup in sup_set.filter(lambda seg: seg.text != "FREETEXT")] + return [sup.duration for sup in sup_set.filter(lambda seg: seg.text != "FREETEXT")] def keep_positives_and_split_negatives( @@ -234,30 +291,19 @@ def keep_positives_and_split_negatives( return CutSet.from_cuts(new_cuts) -def reverb(*args, **kwargs): - """ - Returns a reverb effect for wav augmentation. - """ - import augment - effect_chain = augment.EffectChain() - # Reverb it makes the signal to have two channels, - # which we combine into 1 by running `channels` w/o parameters - effect_chain.reverb(50, 50, lambda: np.random.randint(1, 30)).channels() - return effect_chain +def reverb(sampling_rate: int) -> List[List[str]]: + return [ + ["reverb", 50, 50, RandomValue(1, 30)], + ["remix", "-"], # Merge all channels (reverb changes mono to stereo) + ] -def speed(sampling_rate: int, factor: float): - """ - Returns a speed perturbation effect with for wav augmentation. - :param sampling_rate: a sampling rate value for which the effect will be created (resampling is needed for speed). - :param factor: speed perturbation factor - """ - import augment - effect_chain = augment.EffectChain() - # The speed effect changes the sampling ratio; we have to compensate for that. - # Here, we specify 'quick' options on both pitch and rate effects, to speed up things - effect_chain.speed("-q", lambda: factor).rate("-q", sampling_rate) - return effect_chain +def speed(sampling_rate: int, factor: float) -> List[List[str]]: + return [ + # speed perturbation with a factor + ["speed", factor], + ["rate", sampling_rate], # Resample back to the original sampling rate (speed changes it) + ] if __name__ == "__main__": diff --git a/examples/mobvoihotwords/path.sh b/examples/mobvoihotwords/path.sh index a2576bef6..180398bce 100644 --- a/examples/mobvoihotwords/path.sh +++ b/examples/mobvoihotwords/path.sh @@ -9,8 +9,9 @@ export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/tools/sctk/bin export LC_ALL=C # END -export PATH=~/anaconda3/bin:$PATH +export PATH=/export/b03/ywang/anaconda3/bin:$PATH export PATH=$MAIN_ROOT:$MAIN_ROOT/espresso:$MAIN_ROOT/espresso/tools:$PATH export LD_LIBRARY_PATH=$MAIN_ROOT/espresso/tools/openfst/lib:$LD_LIBRARY_PATH +export LD_LIBRARY_PATH=$MAIN_ROOT/espresso/tools/lhotse/tools/deps/sox-code/src/.libs:$LD_LIBRARY_PATH export PYTHONPATH=$MAIN_ROOT:$MAIN_ROOT/espresso:$MAIN_ROOT/espresso/tools:$MAIN_ROOT/espresso/tools/lhotse:$MAIN_ROOT/espresso/tools/pychain:$PYTHONPATH export PYTHONUNBUFFERED=1 diff --git a/examples/mobvoihotwords/run.sh b/examples/mobvoihotwords/run.sh new file mode 100755 index 000000000..335f88070 --- /dev/null +++ b/examples/mobvoihotwords/run.sh @@ -0,0 +1,29 @@ +#!/bin/bash +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +set -e -o pipefail + +stage=0 + +. ./path.sh +. ./cmd.sh +. 
./utils/parse_options.sh + +if [ ${stage} -le 0 ]; then + echo "Stage 0: Data Preparation" + mkdir -p data/log + ${train_cmd} data/log/data_prep.log \ + python3 local/data_prep.py --data-dir data --seed 1 --num-workers 16 \ + --max-remaining-duration 0.3 --overlap-duration 0.3 +fi + +if [ ${stage} -le 1 ]; then + echo "Stage 1: Graph Generation" +fi + +if [ ${stage} -le 2 ]; then + echo "Stage 2: Model Training" +fi diff --git a/examples/mobvoihotwords/utils b/examples/mobvoihotwords/utils new file mode 120000 index 000000000..bc8958e91 --- /dev/null +++ b/examples/mobvoihotwords/utils @@ -0,0 +1 @@ +../../espresso/tools/kaldi/egs/wsj/s5/utils \ No newline at end of file From 3713dc1ad3d0be812272ecdb644e44ac6fcf3d26 Mon Sep 17 00:00:00 2001 From: freewym Date: Fri, 13 Nov 2020 04:22:08 -0500 Subject: [PATCH 113/119] k2 training related (not yet done) --- espresso/criterions/k2_lf_mmi_loss.py | 181 ++++++++++++++++++ espresso/models/speech_tdnn.py | 14 ++ espresso/tools/Makefile | 7 + .../mobvoihotwords/local/generate_graphs.py | 85 ++++++++ examples/mobvoihotwords/run.sh | 151 ++++++++++++++- 5 files changed, 434 insertions(+), 4 deletions(-) create mode 100644 espresso/criterions/k2_lf_mmi_loss.py create mode 100755 examples/mobvoihotwords/local/generate_graphs.py diff --git a/espresso/criterions/k2_lf_mmi_loss.py b/espresso/criterions/k2_lf_mmi_loss.py new file mode 100644 index 000000000..93dc62181 --- /dev/null +++ b/espresso/criterions/k2_lf_mmi_loss.py @@ -0,0 +1,181 @@ +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field +import logging +import math +from typing import List + +import torch + +from fairseq import utils +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass +from fairseq.logging import metrics +from omegaconf import II + +try: + import k2 +except ImportError: + raise ImportError("Please install k2 by `pip install k2`") + + +logger = logging.getLogger(__name__) + + +@dataclass +class K2LatticeFreeMMICriterionConfig(FairseqDataclass): + sentence_avg: bool = II("optimization.sentence_avg") + denominator_fst_path: str = field( + default="???", metadata={"help": "path to the denominator fst file (torch saved)"} + ) + HCL_fst_path: str = field( + default="???", metadata={"help": "path to the HCL fst file (torch saved)"} + ) + word_symbol_table_path: str = field( + default="???", metadata={"help": "path to the word symbol table file"} + ) + leaky_hmm_coefficient: float = field( + default=1.0e-05, + metadata={"help": "leaky-hmm coefficient for the denominator"}, + ) + xent_regularization_coefficient: float = field( + default=0.0, + metadata={"help": "cross-entropy regularization coefficient"}, + ) + output_l2_regularization_coefficient: float = field( + default=0.0, + metadata={"help": "L2 regularization coefficient for the network's output"}, + ) + + +def create_numerator_graphs(texts: List[str], HCL_fst_inv: k2.Fsa, symbols: k2.SymbolTable): + word_ids_list = [] + for text in texts: + filtered_text = [ + word if word in symbols._sym2id else "" for word in text.split(" ") + ] + word_ids = [symbols.get(word) for word in filtered_text] + word_ids_list.append(word_ids) + + fsa = k2.linear_fsa(word_ids_list) # create an FsaVec from a list of list of word ids + num_graph = k2.intersect(fsa, HCL_fst_inv).invert_() + num_graph = k2.add_epsilon_self_loops(num_graph) + return 
num_graph + + +@register_criterion("k2_lattice_free_mmi", dataclass=K2LatticeFreeMMICriterionConfig) +class K2LatticeFreeMMICriterion(FairseqCriterion): + + def __init__( + self, task, sentence_avg, denominator_fst_path, HCL_fst_path, word_symbol_table_path, + leaky_hmm_coefficient, xent_regularization_coefficient, output_l2_regularization_coefficient, + ): + super().__init__(task) + + self.sentence_avg = sentence_avg + self.den_graph = k2.create_fsa_vec( + k2.Fsa.from_dict(torch.load(denominator_fst_path)) + ) # has to be FsaVec to be able to intersect with a batch of dense fsas + self.den_graph.scores.requires_grad_(False) + self.HCL_fst_inv = k2.Fsa.from_dict(torch.load(HCL_fst_path)).invert_() + self.symbol_table = k2.SymbolTable.from_file(word_symbol_table_path) + self.leaky_hmm_coefficient = leaky_hmm_coefficient + self.xent_regularize = xent_regularization_coefficient + self.output_l2_regularize = output_l2_regularization_coefficient + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample["net_input"]) + loss, nll_loss = self.compute_loss(net_output, sample, reduce=reduce) + + sample_size = ( + sample["target"].batch_size if self.sentence_avg else sample["ntokens"] + ) + logging_output = { + "loss": loss.data, + "nll_loss": nll_loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["nsentences"], + "sample_size": sample_size, + } + return loss, sample_size, logging_output + + def compute_loss(self, net_output, sample, reduce=True): + # create the dense fsts from the network's output + encoder_out = net_output.encoder_out.transpose(0, 1) # T x B x V -> B x T x V + out_lengths = net_output.src_lengths.long() # B + supervision_segments = torch.stack( + # seq_index, start_frame, lengths + (sample["target"]["sequence_idx"], sample["target"]["start_frame"], out_lengths), + dim=1 + ) + dense_fsa_vec = k2.DenseFsaVec(encoder_out, supervision_segments) + + # numerator computation + num_graphs = create_numerator_graphs(sample["target"]["text"], self.HCL_fst_inv, self.symbol_table) + num_graphs.to_(encoder_out.device) + num_graphs.scores.requires_grad_(False) + num_graphs_unrolled = k2.intersect_dense_pruned( + num_graphs, dense_fsa_vec, beam=100000, max_active_states=10000, min_active_states=0 + ) + num_scores = k2.get_tot_scores(num_graphs_unrolled, log_semiring=False, use_float_scores=True) + + # denominator computation + self.den_graph.to_(encoder_out.device) + den_graph_unrolled = k2.intersect_dense_pruned( + self.den_graph, dense_fsa_vec, beam=100000, max_active_states=10000, min_active_states=0 + ) + den_scores = k2.get_tot_scores(den_graph_unrolled, log_semiring=False, use_float_scores=True) + + # obtain the loss + loss = -num_scores + den_scores + nll_loss = loss.clone().detach() + if self.xent_regularize > 0.0: + loss -= self.xent_regularize * num_scores + + if self.output_l2_regularize > 0.0: + encoder_padding_mask = net_output.encoder_padding_mask + encoder_out_squared = encoder_out.pow(2.0) + if encoder_padding_mask is not None: + pad_mask = encoder_padding_mask.transpose(0, 1).unsqueeze(-1) # T x B -> B x T x 1 + encoder_out_squared.masked_fill_(pad_mask, 0.0) + loss += 0.5 * self.output_l2_regularize * encoder_out_squared.sum() + + return loss, nll_loss + + @classmethod + def reduce_metrics(cls, logging_outputs) -> None: + 
"""Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + + # we divide by log(2) to convert the loss from base e to base 2 + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=7 + ) + metrics.log_scalar( + "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=7 + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg, round=4) + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/espresso/models/speech_tdnn.py b/espresso/models/speech_tdnn.py index 7443a3b65..9014e432f 100644 --- a/espresso/models/speech_tdnn.py +++ b/espresso/models/speech_tdnn.py @@ -339,3 +339,17 @@ def base_architecture(args): @register_model_architecture("speech_tdnn", "speech_tdnn_wsj") def tdnn_wsj(args): base_architecture(args) + + +@register_model_architecture("speech_tdnn", "speech_tdnn_mobvoi") +def tdnn_mobvoi(args): + args.dropout = getattr(args, "dropout", 0.0) + args.hidden_sizes = getattr(args, "hidden_sizes", "64") + args.kernel_sizes = getattr(args, "kernel_sizes", "[3] * 5") + args.strides = getattr(args, "strides", "1") + args.dilations = getattr(args, "dilations", "3") + args.num_layers = getattr(args, "num_layers", 5) + args.residual = getattr(args, "residual", False) + args.dropout_in = getattr(args, "dropout_in", args.dropout) + args.dropout_out = getattr(args, "dropout_out", args.dropout) + base_architecture(args) diff --git a/espresso/tools/Makefile b/espresso/tools/Makefile index 8ed219a20..a95186e58 100644 --- a/espresso/tools/Makefile +++ b/espresso/tools/Makefile @@ -1,5 +1,6 @@ KALDI = PYTHON_DIR = /export/b03/ywang/anaconda3/bin +CMAKE = /home/ywang/cmake-3.18.4-Linux-x86_64/bin/cmake CXX ?= g++ @@ -30,6 +31,7 @@ kaldi: endif clean: openfst_cleaned + rm -rf k2 rm -rf lhotse rm -rf pychain rm -rf kaldi @@ -85,3 +87,8 @@ pychain: lhotse: test -d lhotse || git clone https://github.com/lhotse-speech/lhotse.git export PATH=$(PYTHON_DIR):$$PATH && cd lhotse && pip install -e . + +.PHONY: k2 +k2: + test -d k2 || git clone https://github.com/k2-fsa/k2.git + cd k2 && mkdir -p build && cd build && $(CMAKE) .. && $(MAKE) diff --git a/examples/mobvoihotwords/local/generate_graphs.py b/examples/mobvoihotwords/local/generate_graphs.py new file mode 100755 index 000000000..49df9e628 --- /dev/null +++ b/examples/mobvoihotwords/local/generate_graphs.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python3 +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse +import logging +import os +import sys + +import torch + + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("mobvoihotwords.generate_graphs") + + +def get_parser(): + parser = argparse.ArgumentParser( + description="Generate graphs for training" + ) + # fmt: off + parser.add_argument("--hmm-paths", nargs="+", help="list of HMM paths", required=True) + parser.add_argument("--lexicon-fst-path", type=str, help="path to the lexicon fst", required=True) + parser.add_argument("--phone-lm-fsa-path", type=str, help="path to the phone LM fsa", required=True) + parser.add_argument("--out-dir", type=str, default="data", help="directory to save output graphs") + # fmt: on + + return parser + + +def main(args): + try: + import k2 + except ImportError: + raise ImportError("Please install k2 by `pip install k2`") + + H = [] + hmm_paths = args.hmm_paths + for hmm in hmm_paths: + with open(hmm, "r", encoding="utf-8") as f: + H.append(k2.Fsa.from_openfst(f.read(), acceptor=False)) + H[-1] = k2.arc_sort(H[-1]) + #TODO: H = fsa.union([H]) + H = k2.arc_sort(H.invert_()) + + with open(args.lexicon_fst_path, "r", encoding="utf-8") as f: + L = k2.Fsa.from_openfst(f.read(), acceptor=False) + L = k2.arc_sort(L.invert_()).invert_() # sort on olabels + + with open(args.phone_lm_fst_path, "r", encoding="utf-8") as f: + phone_lm = k2.Fsa.from_openfst(f.read(), acceptor=True) + phone_lm = k2.arc_sort(phone_lm) + + # emulate composition + if hasattr(L, "aux_symbols"): + setattr(L, "temp_symbols", L.aux_symbols) + delattr(L, "aux_symbols") + HL = k2.intersect(H, L) + if hasattr(L, "temp_symbols"): + setattr(L, "aux_symbols", L.temp_symbols) + delattr(L, "temp_symbols") + HL = k2.arc_sort(HL) + save_path = os.path.join(args.out_dir, "HL.pt") + torch.save(HL.as_dict(), save_path) + logger.info(f"save HL as {save_path}") + + den_graph = k2.intersect(H, phone_lm).invert_() + den_graph = k2.arc_sort(den_graph) + #den_graph = k2.add_epsilon_self_loops(den_graph) + save_path = os.path.join(args.out_dir, "denominator.pt") + torch.save(den_graph.as_dict(), save_path) + logger.info(f"save the denominator graph as {save_path}") + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + main(args) diff --git a/examples/mobvoihotwords/run.sh b/examples/mobvoihotwords/run.sh index 335f88070..8d8d77d19 100755 --- a/examples/mobvoihotwords/run.sh +++ b/examples/mobvoihotwords/run.sh @@ -7,23 +7,166 @@ set -e -o pipefail stage=0 +ngpus=1 # num GPUs for multiple GPUs training within a single node; should match those in $free_gpu +free_gpu= # comma-separated available GPU ids, eg., "0" or "0,1"; automatically assigned if on CLSP grid + +# model and data related +affix= +test_set="dev eval" +checkpoint=checkpoint_best.pt +wake_word0="HiXiaowen" +wake_word1="NihaoWenwen" + . ./path.sh . ./cmd.sh . 
./utils/parse_options.sh +dir=exp/tdnn_k2_${affix:+_$affix} + if [ ${stage} -le 0 ]; then echo "Stage 0: Data Preparation" mkdir -p data/log - ${train_cmd} data/log/data_prep.log \ - python3 local/data_prep.py --data-dir data --seed 1 --num-workers 16 \ - --max-remaining-duration 0.3 --overlap-duration 0.3 + log_file=data/log/data_prep.log + $train_cmd $log_file ./local/data_prep.py --data-dir data --seed 1 \ + --num-workers 16 --max-remaining-duration 0.3 --overlap-duration 0.3 fi if [ ${stage} -le 1 ]; then - echo "Stage 1: Graph Generation" + echo "Stage 1: Graphs Generation" + echo "Prepare the lexicon" + mkdir -p data/lang + cat > data/lang/lexiconp.txt < 1.0 SIL +EOF + + utils/lang/make_lexicon_fst.py --sil-prob 0.5 --sil-phone SIL \ + data/lang/lexiconp.txt > data/lang/L.fst.txt.sym + + echo "Prepare phones symbol table" + cat > data/lang/phones.txt < 0 +SIL 1 +hixiaowen 2 +nihaowenwen 3 +freetext 4 +EOF + + echo "Prepare words symbol table" + cat > data/lang/words.txt < 0 + 1 +FREETEXT 2 +HiXiaowen 3 +NihaoWenwen 4 +EOF + + utils/sym2int.pl -f 3 data/lang/phones.txt data/lang/L.fst.txt + + echo "Prepare HMMs for phones" + id_sil=`cat data/lang/phones.txt | grep "SIL" | awk '{print $2}'` + id_freetext=`cat data/lang/phones.txt | grep "freetext" | awk '{print $2}'` + id_word0=`cat data/lang/phones.txt | grep "hixiaowen" | awk '{print $2}'` + id_word1=`cat data/lang/phones.txt | grep "nihaowenwen" | awk '{print $2}'` + id_freetext=`cat data/lang/phones.txt | grep "freetext" | awk '{print $2}'` + + cat > data/lang/hmm_sil.txt < data/lang/hmm_freetext.txt < data/lang/hmm_hixiaowen.txt < data/lang/hmm_nihaowenwen.txt < data/lang/phone_lm.txt +0 1 $id_sil +0 5 $id_sil +1 2 $id_word0 +2 3 $id_sil +1 4 $id_word1 +4 5 $id_sil +1 6 $id_freetext +6 7 $id_sil +3 2.3 +5 2.3 +7 0.0 +EOF + + echo "Generate graphs for training" + local/generate_graphs.py --hmm-paths data/lang/hmm_{sil,freetext,hixiaowen,nihaowenwen}.txt \ + --lexicon-fst-path data/lang/L.fst.txt --phone-lm-fsa-path data/lang/phone_lm.txt \ + --out-dir data fi +[ -z "$free_gpu" ] && [[ $(hostname -f) == *.clsp.jhu.edu ]] && free_gpu=$(free-gpu -n $ngpus) || \ + echo "Unable to get $ngpus GPUs" +[ -z "$free_gpu" ] && echo "$0: please specify --free-gpu" && exit 1; +[ $(echo $free_gpu | sed 's/,/ /g' | awk '{print NF}') -ne "$ngpus" ] && \ + echo "number of GPU ids in --free-gpu=$free_gpu does not match --ngpus=$ngpus" && exit 1; + +num_targets=26 # hard-coded for now. 
It's equal to the number of different labels in data/lang/hmm_*.txt + if [ ${stage} -le 2 ]; then echo "Stage 2: Model Training" + opts="" + valid_subset=dev + mkdir -p $dir/log + log_file=$dir/log/train.log + [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" + update_freq=1 + CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data --task speech_recognition_hybrid --seed 1 \ + --log-interval $((1500/ngpus/update_freq)) --log-format simple \ + --num-workers 0 --data-buffer-size 0 --max-tokens 25600 --batch-size 128 --empty-cache-freq 50 \ + --valid-subset $valid_subset --batch-size-valid 128 --ddp-backend no_c10d --update-freq $update_freq \ + --distributed-world-size $ngpus --arch speech_tdnn_mobvoi \ + --max-epoch 15 --optimizer adam --lr 0.001 --weight-decay 0.0 \ + --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 \ + --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates $((1500/ngpus/update_freq)) \ + --keep-interval-updates 5 --keep-last-epochs 5 --validate-interval 1 \ + --criterion k2_lattice_free_mmi --num-targets $num_targets --word-symbol-path data/lang/words.txt \ + --denominator-fst-path data/denominator.pt --HCL-fst-path data/HL.pt \ + --max-source-positions 9999 --max-target-positions 9999 $opts 2>&1 | tee $log_file +fi + +if [ ${stage} -le 3 ]; then + echo "Stage 3: Decoding" fi From 4fb0d8776b6db26ea8798dd89e32c19d7fd6f38a Mon Sep 17 00:00:00 2001 From: freewym Date: Sat, 14 Nov 2020 02:03:32 -0500 Subject: [PATCH 114/119] fixes --- espresso/criterions/k2_lf_mmi_loss.py | 47 ++++++++----------- espresso/data/k2_asr_dataset.py | 33 +++++-------- .../mobvoihotwords/local/generate_graphs.py | 23 +++++---- examples/mobvoihotwords/run.sh | 18 +++---- 4 files changed, 51 insertions(+), 70 deletions(-) diff --git a/espresso/criterions/k2_lf_mmi_loss.py b/espresso/criterions/k2_lf_mmi_loss.py index 93dc62181..e221e07d6 100644 --- a/espresso/criterions/k2_lf_mmi_loss.py +++ b/espresso/criterions/k2_lf_mmi_loss.py @@ -6,13 +6,16 @@ from dataclasses import dataclass, field import logging import math -from typing import List +from typing import Any, Dict, List, Optional import torch from fairseq import utils from fairseq.criterions import FairseqCriterion, register_criterion from fairseq.dataclass import FairseqDataclass +from fairseq.models import BaseFairseqModel +from fairseq.models.fairseq_encoder import EncoderOut +from fairseq.tasks import FairseqTask from fairseq.logging import metrics from omegaconf import II @@ -37,18 +40,10 @@ class K2LatticeFreeMMICriterionConfig(FairseqDataclass): word_symbol_table_path: str = field( default="???", metadata={"help": "path to the word symbol table file"} ) - leaky_hmm_coefficient: float = field( - default=1.0e-05, - metadata={"help": "leaky-hmm coefficient for the denominator"}, - ) xent_regularization_coefficient: float = field( default=0.0, metadata={"help": "cross-entropy regularization coefficient"}, ) - output_l2_regularization_coefficient: float = field( - default=0.0, - metadata={"help": "L2 regularization coefficient for the network's output"}, - ) def create_numerator_graphs(texts: List[str], HCL_fst_inv: k2.Fsa, symbols: k2.SymbolTable): @@ -62,7 +57,6 @@ def create_numerator_graphs(texts: List[str], HCL_fst_inv: k2.Fsa, symbols: k2.S fsa = k2.linear_fsa(word_ids_list) # create an FsaVec from a list of list of word ids num_graph = k2.intersect(fsa, HCL_fst_inv).invert_() - num_graph = k2.add_epsilon_self_loops(num_graph) return num_graph @@ -70,23 +64,23 @@ def create_numerator_graphs(texts: List[str], 
HCL_fst_inv: k2.Fsa, symbols: k2.S class K2LatticeFreeMMICriterion(FairseqCriterion): def __init__( - self, task, sentence_avg, denominator_fst_path, HCL_fst_path, word_symbol_table_path, - leaky_hmm_coefficient, xent_regularization_coefficient, output_l2_regularization_coefficient, + self, task: FairseqTask, sentence_avg: bool, denominator_fst_path: str, + HCL_fst_path: str, word_symbol_table_path: str, xent_regularization_coefficient: float, ): super().__init__(task) self.sentence_avg = sentence_avg self.den_graph = k2.create_fsa_vec( k2.Fsa.from_dict(torch.load(denominator_fst_path)) - ) # has to be FsaVec to be able to intersect with a batch of dense fsas + ) # has to be an FsaVec to be able to intersect with a batch of dense fsas self.den_graph.scores.requires_grad_(False) self.HCL_fst_inv = k2.Fsa.from_dict(torch.load(HCL_fst_path)).invert_() self.symbol_table = k2.SymbolTable.from_file(word_symbol_table_path) - self.leaky_hmm_coefficient = leaky_hmm_coefficient self.xent_regularize = xent_regularization_coefficient - self.output_l2_regularize = output_l2_regularization_coefficient - def forward(self, model, sample, reduce=True): + def forward( + self, model: BaseFairseqModel, sample: List[Dict[str, Any]], reduce: Optional[bool] = True, + ): """Compute the loss for the given sample. Returns a tuple with three elements: @@ -109,7 +103,9 @@ def forward(self, model, sample, reduce=True): } return loss, sample_size, logging_output - def compute_loss(self, net_output, sample, reduce=True): + def compute_loss( + self, net_output: EncoderOut, sample: List[Dict[str, Any]], reduce: Optional[bool] = True, + ): # create the dense fsts from the network's output encoder_out = net_output.encoder_out.transpose(0, 1) # T x B x V -> B x T x V out_lengths = net_output.src_lengths.long() # B @@ -127,29 +123,24 @@ def compute_loss(self, net_output, sample, reduce=True): num_graphs_unrolled = k2.intersect_dense_pruned( num_graphs, dense_fsa_vec, beam=100000, max_active_states=10000, min_active_states=0 ) - num_scores = k2.get_tot_scores(num_graphs_unrolled, log_semiring=False, use_float_scores=True) + num_scores = k2.get_tot_scores(num_graphs_unrolled, log_semiring=True, use_float_scores=True) # denominator computation self.den_graph.to_(encoder_out.device) den_graph_unrolled = k2.intersect_dense_pruned( self.den_graph, dense_fsa_vec, beam=100000, max_active_states=10000, min_active_states=0 ) - den_scores = k2.get_tot_scores(den_graph_unrolled, log_semiring=False, use_float_scores=True) + den_scores = k2.get_tot_scores(den_graph_unrolled, log_semiring=True, use_float_scores=True) # obtain the loss - loss = -num_scores + den_scores + if reduce: + num_scores = num_scores.sum() + den_scores = den_scores.sum() + loss = -num_scores + den_scores # negative log-probs nll_loss = loss.clone().detach() if self.xent_regularize > 0.0: loss -= self.xent_regularize * num_scores - if self.output_l2_regularize > 0.0: - encoder_padding_mask = net_output.encoder_padding_mask - encoder_out_squared = encoder_out.pow(2.0) - if encoder_padding_mask is not None: - pad_mask = encoder_padding_mask.transpose(0, 1).unsqueeze(-1) # T x B -> B x T x 1 - encoder_out_squared.masked_fill_(pad_mask, 0.0) - loss += 0.5 * self.output_l2_regularize * encoder_out_squared.sum() - return loss, nll_loss @classmethod diff --git a/espresso/data/k2_asr_dataset.py b/espresso/data/k2_asr_dataset.py index d67d707ff..6fbda3407 100644 --- a/espresso/data/k2_asr_dataset.py +++ b/espresso/data/k2_asr_dataset.py @@ -37,20 +37,11 @@ def collate( if 
len(samples) == 0: return {} - def merge(key, pad_to_length=None): - if key == "source": - return speech_utils.collate_frames( - [sample[key] for sample in samples], 0.0, - pad_to_length=pad_to_length, - pad_to_multiple=pad_to_multiple, - ) - else: - raise ValueError("Invalid key.") - id = torch.LongTensor([sample["id"] for sample in samples]) - src_frames = merge( - "source", + src_frames = speech_utils.collate_frames( + [sample["source"] for sample in samples], 0.0, pad_to_length=pad_to_length["source"] if pad_to_length is not None else None, + pad_to_multiple=pad_to_multiple, ) # sort by descending source length if pad_to_length is not None: @@ -141,7 +132,7 @@ def __init__( self.shuffle = shuffle self.epoch = 1 self.sizes = ( - np.vstack((self.src_sizes, self.tgt_sizes)).T + np.stack((self.src_sizes, self.tgt_sizes), axis=1) if self.tgt_sizes is not None else self.src_sizes ) @@ -192,18 +183,18 @@ def collater( Returns: dict: a mini-batch with the following keys: - - `id` (LongTensor): example IDs in the original input order - - `utt_id` (List[str]): list of utterance ids - - `nsentences` (int): batch size - - `ntokens` (int): total number of tokens in the batch - - `net_input` (dict): the input to the Model, containing keys: + - `id` -> LongTensor: example IDs in the original input order + - `utt_id` -> List[str]: list of utterance ids + - `nsentences` -> int: batch size + - `ntokens` -> int: total number of tokens in the batch + - `net_input` -> Dict: the input to the Model, containing keys: - - `src_tokens` (FloatTensor): a padded 3D Tensor of features in + - `src_tokens` -> FloatTensor: a padded 3D Tensor of features in the source of shape `(bsz, src_len, feat_dim)`. - - `src_lengths` (IntTensor): 1D Tensor of the unpadded + - `src_lengths` -> IntTensor: 1D Tensor of the unpadded lengths of each source sequence of shape `(bsz)` - - `target` (List[Dict[str, Any]]): a List representing a batch of + - `target` -> List[Dict[str, Any]]: a List representing a batch of supervisions """ return collate( diff --git a/examples/mobvoihotwords/local/generate_graphs.py b/examples/mobvoihotwords/local/generate_graphs.py index 49df9e628..88f16954d 100755 --- a/examples/mobvoihotwords/local/generate_graphs.py +++ b/examples/mobvoihotwords/local/generate_graphs.py @@ -41,14 +41,14 @@ def main(args): except ImportError: raise ImportError("Please install k2 by `pip install k2`") - H = [] - hmm_paths = args.hmm_paths - for hmm in hmm_paths: + hmms = [] + for hmm in args.hmm_paths: with open(hmm, "r", encoding="utf-8") as f: - H.append(k2.Fsa.from_openfst(f.read(), acceptor=False)) - H[-1] = k2.arc_sort(H[-1]) - #TODO: H = fsa.union([H]) - H = k2.arc_sort(H.invert_()) + hmms.append(k2.Fsa.from_openfst(f.read(), acceptor=False)) + hmms[-1] = k2.arc_sort(hmms[-1]) + hmm_vec = k2.create_fsa_vec(hmms) + H = k2.union(hmm_vec) + H_inv = k2.arc_sort(H.invert_()) with open(args.lexicon_fst_path, "r", encoding="utf-8") as f: L = k2.Fsa.from_openfst(f.read(), acceptor=False) @@ -62,21 +62,20 @@ def main(args): if hasattr(L, "aux_symbols"): setattr(L, "temp_symbols", L.aux_symbols) delattr(L, "aux_symbols") - HL = k2.intersect(H, L) + HL = k2.intersect(H_inv, L) if hasattr(L, "temp_symbols"): setattr(L, "aux_symbols", L.temp_symbols) delattr(L, "temp_symbols") HL = k2.arc_sort(HL) save_path = os.path.join(args.out_dir, "HL.pt") torch.save(HL.as_dict(), save_path) - logger.info(f"save HL as {save_path}") + logger.info(f"saved the HL fst as {save_path}") - den_graph = k2.intersect(H, phone_lm).invert_() + den_graph 
= k2.intersect(H_inv, phone_lm).invert_() den_graph = k2.arc_sort(den_graph) - #den_graph = k2.add_epsilon_self_loops(den_graph) save_path = os.path.join(args.out_dir, "denominator.pt") torch.save(den_graph.as_dict(), save_path) - logger.info(f"save the denominator graph as {save_path}") + logger.info(f"saved the denominator graph as {save_path}") if __name__ == "__main__": diff --git a/examples/mobvoihotwords/run.sh b/examples/mobvoihotwords/run.sh index 8d8d77d19..b1ff75f61 100755 --- a/examples/mobvoihotwords/run.sh +++ b/examples/mobvoihotwords/run.sh @@ -44,7 +44,7 @@ FREETEXT 1.0 freetext EOF utils/lang/make_lexicon_fst.py --sil-prob 0.5 --sil-phone SIL \ - data/lang/lexiconp.txt > data/lang/L.fst.txt.sym + data/lang/lexiconp.txt > data/lang/L.fst.sym echo "Prepare phones symbol table" cat > data/lang/phones.txt < data/lang/L.fst.txt echo "Prepare HMMs for phones" @@ -74,13 +74,13 @@ EOF id_word1=`cat data/lang/phones.txt | grep "nihaowenwen" | awk '{print $2}'` id_freetext=`cat data/lang/phones.txt | grep "freetext" | awk '{print $2}'` - cat > data/lang/hmm_sil.txt < data/lang/hmm_sil.fst.txt < data/lang/hmm_freetext.txt < data/lang/hmm_freetext.fst.txt < data/lang/hmm_hixiaowen.txt < data/lang/hmm_hixiaowen.fst.txt < data/lang/hmm_nihaowenwen.txt < data/lang/hmm_nihaowenwen.fst.txt < data/lang/phone_lm.txt + cat < data/lang/phone_lm.fsa.txt 0 1 $id_sil 0 5 $id_sil 1 2 $id_word0 @@ -132,8 +132,8 @@ EOF EOF echo "Generate graphs for training" - local/generate_graphs.py --hmm-paths data/lang/hmm_{sil,freetext,hixiaowen,nihaowenwen}.txt \ - --lexicon-fst-path data/lang/L.fst.txt --phone-lm-fsa-path data/lang/phone_lm.txt \ + local/generate_graphs.py --hmm-paths data/lang/hmm_{sil,freetext,hixiaowen,nihaowenwen}.fst.txt \ + --lexicon-fst-path data/lang/L.fst.txt --phone-lm-fsa-path data/lang/phone_lm.fsa.txt \ --out-dir data fi From dc3874af4e55e497991cca4fbaafd1febeed401e Mon Sep 17 00:00:00 2001 From: freewym Date: Mon, 16 Nov 2020 04:04:50 -0500 Subject: [PATCH 115/119] f --- espresso/criterions/k2_lf_mmi_loss.py | 73 +++++++++++------ espresso/data/k2_asr_dataset.py | 15 ++-- espresso/tasks/speech_recognition_hybrid.py | 7 +- espresso/tools/Makefile | 13 --- examples/mobvoihotwords/local/data_prep.py | 82 ++++++++++--------- .../mobvoihotwords/local/generate_graphs.py | 27 +++--- examples/mobvoihotwords/run.sh | 73 +++++++++-------- setup.py | 2 + 8 files changed, 161 insertions(+), 131 deletions(-) diff --git a/espresso/criterions/k2_lf_mmi_loss.py b/espresso/criterions/k2_lf_mmi_loss.py index e221e07d6..10cc153b9 100644 --- a/espresso/criterions/k2_lf_mmi_loss.py +++ b/espresso/criterions/k2_lf_mmi_loss.py @@ -6,18 +6,18 @@ from dataclasses import dataclass, field import logging import math +from omegaconf import II from typing import Any, Dict, List, Optional import torch +from torch import Tensor from fairseq import utils from fairseq.criterions import FairseqCriterion, register_criterion from fairseq.dataclass import FairseqDataclass from fairseq.models import BaseFairseqModel -from fairseq.models.fairseq_encoder import EncoderOut from fairseq.tasks import FairseqTask from fairseq.logging import metrics -from omegaconf import II try: import k2 @@ -46,7 +46,7 @@ class K2LatticeFreeMMICriterionConfig(FairseqDataclass): ) -def create_numerator_graphs(texts: List[str], HCL_fst_inv: k2.Fsa, symbols: k2.SymbolTable): +def create_numerator_graphs(texts: List[str], HCL_fst_inv: k2.Fsa, symbols: k2.SymbolTable, den_graph=None): word_ids_list = [] for text in texts: filtered_text = [ @@ 
-56,27 +56,34 @@ def create_numerator_graphs(texts: List[str], HCL_fst_inv: k2.Fsa, symbols: k2.S word_ids_list.append(word_ids) fsa = k2.linear_fsa(word_ids_list) # create an FsaVec from a list of list of word ids - num_graph = k2.intersect(fsa, HCL_fst_inv).invert_() - return num_graph + num_graphs = k2.intersect(fsa, HCL_fst_inv).invert_() + # TODO: normalize numerator + if False: #den_graph is not None: + num_graphs = k2.arc_sort(num_graphs) + num_graphs.scores = num_graphs.scores.new_zeros(num_graphs.scores.size()) # zero the score before intersect to avoid double counting + num_graphs = k2.intersect(num_graphs, den_graph) + return num_graphs @register_criterion("k2_lattice_free_mmi", dataclass=K2LatticeFreeMMICriterionConfig) class K2LatticeFreeMMICriterion(FairseqCriterion): - - def __init__( - self, task: FairseqTask, sentence_avg: bool, denominator_fst_path: str, - HCL_fst_path: str, word_symbol_table_path: str, xent_regularization_coefficient: float, - ): + def __init__(self, cfg: K2LatticeFreeMMICriterionConfig, task: FairseqTask): super().__init__(task) - self.sentence_avg = sentence_avg + self.sentence_avg = cfg.sentence_avg self.den_graph = k2.create_fsa_vec( - k2.Fsa.from_dict(torch.load(denominator_fst_path)) + [k2.Fsa.from_dict(torch.load(cfg.denominator_fst_path))] ) # has to be an FsaVec to be able to intersect with a batch of dense fsas + if hasattr(self.den_graph, "aux_labels"): + del self.den_graph.aux_labels + if hasattr(self.den_graph, "aux_symbols"): + del self.den_graph.aux_symbols self.den_graph.scores.requires_grad_(False) - self.HCL_fst_inv = k2.Fsa.from_dict(torch.load(HCL_fst_path)).invert_() - self.symbol_table = k2.SymbolTable.from_file(word_symbol_table_path) - self.xent_regularize = xent_regularization_coefficient + self.den_graph_cpu = self.den_graph.clone() + self.HCL_fst_inv = k2.arc_sort(k2.Fsa.from_dict(torch.load(cfg.HCL_fst_path)).invert_()) + self.symbol_table = k2.SymbolTable.from_file(cfg.word_symbol_table_path) + self.xent_regularize = cfg.xent_regularization_coefficient + self.subsampling_factor = None def forward( self, model: BaseFairseqModel, sample: List[Dict[str, Any]], reduce: Optional[bool] = True, @@ -88,6 +95,10 @@ def forward( 2) the sample size, which is used as the denominator for the gradient 3) logging outputs to display while training """ + if self.subsampling_factor is None: + assert hasattr(model, "output_lengths"), "model should implement the method `output_lengths()`" + self.subsampling_factor = int(round(100.0 / model.output_lengths(100))) + net_output = model(**sample["net_input"]) loss, nll_loss = self.compute_loss(net_output, sample, reduce=reduce) @@ -104,36 +115,48 @@ def forward( return loss, sample_size, logging_output def compute_loss( - self, net_output: EncoderOut, sample: List[Dict[str, Any]], reduce: Optional[bool] = True, + self, net_output: Dict[str, List[Tensor]], sample: List[Dict[str, Any]], reduce: Optional[bool] = True, ): # create the dense fsts from the network's output - encoder_out = net_output.encoder_out.transpose(0, 1) # T x B x V -> B x T x V - out_lengths = net_output.src_lengths.long() # B + encoder_out = net_output["encoder_out"][0].transpose(0, 1) # T x B x V -> B x T x V + if torch.isnan(encoder_out).int().sum().item() > 0 or torch.isinf(encoder_out).int().sum().item() > 0: + print("nan",torch.isnan(encoder_out).int().sum().item(), "inf", torch.isinf(encoder_out).int().sum().item()) + encoder_out = encoder_out.clamp(-30, 30) # clamp to avoid numerical overflows + out_lengths = 
net_output["src_lengths"][0] # B supervision_segments = torch.stack( # seq_index, start_frame, lengths - (sample["target"]["sequence_idx"], sample["target"]["start_frame"], out_lengths), + ( + sample["target"]["sequence_idx"], + torch.floor_divide(sample["target"]["start_frame"], self.subsampling_factor), + out_lengths + ), dim=1 - ) + ).int().cpu() dense_fsa_vec = k2.DenseFsaVec(encoder_out, supervision_segments) # numerator computation - num_graphs = create_numerator_graphs(sample["target"]["text"], self.HCL_fst_inv, self.symbol_table) - num_graphs.to_(encoder_out.device) + num_graphs = create_numerator_graphs( + sample["target"]["text"], self.HCL_fst_inv, self.symbol_table, den_graph=self.den_graph_cpu + ).to(encoder_out.device) num_graphs.scores.requires_grad_(False) num_graphs_unrolled = k2.intersect_dense_pruned( - num_graphs, dense_fsa_vec, beam=100000, max_active_states=10000, min_active_states=0 + num_graphs, dense_fsa_vec, search_beam=100000, output_beam=100000, min_active_states=0, max_active_states=10000 ) num_scores = k2.get_tot_scores(num_graphs_unrolled, log_semiring=True, use_float_scores=True) # denominator computation - self.den_graph.to_(encoder_out.device) + self.den_graph = self.den_graph.to(encoder_out.device) den_graph_unrolled = k2.intersect_dense_pruned( - self.den_graph, dense_fsa_vec, beam=100000, max_active_states=10000, min_active_states=0 + self.den_graph, dense_fsa_vec, search_beam=100000, output_beam=100000, min_active_states=0, max_active_states=10000 ) den_scores = k2.get_tot_scores(den_graph_unrolled, log_semiring=True, use_float_scores=True) # obtain the loss if reduce: + if torch.isnan(num_scores).int().sum().item() > 0 or torch.isinf(num_scores).int().sum().item() > 0: + print("num nan", torch.isnan(num_scores).int().sum().item(), "inf", torch.isinf(num_scores).int().sum().item()) + if torch.isnan(den_scores).int().sum().item() > 0 or torch.isinf(den_scores).int().sum().item() > 0: + print("den nan", torch.isnan(den_scores).int().sum().item(), "inf", torch.isinf(den_scores).int().sum().item()) num_scores = num_scores.sum() den_scores = den_scores.sum() loss = -num_scores + den_scores # negative log-probs diff --git a/espresso/data/k2_asr_dataset.py b/espresso/data/k2_asr_dataset.py index 6fbda3407..666fceb17 100644 --- a/espresso/data/k2_asr_dataset.py +++ b/espresso/data/k2_asr_dataset.py @@ -16,10 +16,10 @@ import espresso.tools.utils as speech_utils try: - # TODO use pip install once it's available - from espresso.tools.lhotse.lhotse import CutSet + from lhotse import CutSet + from lhotse.utils import compute_num_frames except ImportError: - raise ImportError("Please install Lhotse by `make lhotse` after entering espresso/tools") + raise ImportError("Please install Lhotse by `pip install lhotse`") def collate( @@ -53,11 +53,12 @@ def collate( src_lengths, sort_order = src_lengths.sort(descending=True) id = id.index_select(0, sort_order) utt_id = [samples[i]["utt_id"] for i in sort_order.numpy()] + reco_id = [samples[i]["reco_id"] for i in sort_order.numpy()] src_frames = src_frames.index_select(0, sort_order) ntokens = src_lengths.sum().item() target = None - if samples[0].get("target", None) is not None and len(samples[0].target) > 0: + if samples[0].get("target", None) is not None and len(samples[0]["target"]) > 0: # reorder the list of samples to make things easier # (no need to reorder every element in target) samples = [samples[i] for i in sort_order.numpy()] @@ -83,6 +84,7 @@ def update(d: Dict, **kwargs) -> Dict: batch = { "id": id, "utt_id": 
utt_id, + "reco_id": reco_id, "nsentences": len(samples), "ntokens": ntokens, "net_input": { @@ -147,13 +149,14 @@ def __getitem__(self, index): example = { "id": index, "utt_id": cut_id, + "reco_id": cut.recording_id, "source": features, "target": [ { "sequence_idx": index, "text": sup.text, - "start_frame": round(sup.start / cut.frame_shift), - "num_frames": round(sup.duration / cut.frame_shift), + "start_frame": compute_num_frames(sup.start, cut.frame_shift), + "num_frames": compute_num_frames(sup.duration, cut.frame_shift), } # CutSet's supervisions can exceed the cut, when the cut starts/ends in the middle # of a supervision (they would have relative times e.g. -2 seconds start, meaning diff --git a/espresso/tasks/speech_recognition_hybrid.py b/espresso/tasks/speech_recognition_hybrid.py index 711b561eb..e718c3141 100644 --- a/espresso/tasks/speech_recognition_hybrid.py +++ b/espresso/tasks/speech_recognition_hybrid.py @@ -150,12 +150,11 @@ class SpeechRecognitionHybridConfig(FairseqDataclass): def get_k2_dataset_from_json(data_path, split, shuffle=True, pad_to_multiple=1, seed=1): try: - # TODO use pip install once it's available - from espresso.tools.lhotse.lhotse import CutSet + from lhotse import CutSet except ImportError: - raise ImportError("Please install Lhotse by `make lhotse` after entering espresso/tools") + raise ImportError("Please install Lhotse by `pip install lhotse`") - data_json_path = os.path.join(data_path, "cuts_{}.json".format(split)) + data_json_path = os.path.join(data_path, "cuts_{}.json.gz".format(split)) if not os.path.isfile(data_json_path): raise FileNotFoundError("Dataset not found: {}".format(data_json_path)) diff --git a/espresso/tools/Makefile b/espresso/tools/Makefile index a95186e58..80ae1eb7b 100644 --- a/espresso/tools/Makefile +++ b/espresso/tools/Makefile @@ -1,6 +1,5 @@ KALDI = PYTHON_DIR = /export/b03/ywang/anaconda3/bin -CMAKE = /home/ywang/cmake-3.18.4-Linux-x86_64/bin/cmake CXX ?= g++ @@ -31,8 +30,6 @@ kaldi: endif clean: openfst_cleaned - rm -rf k2 - rm -rf lhotse rm -rf pychain rm -rf kaldi @@ -82,13 +79,3 @@ pychain: export PATH=$(PYTHON_DIR):$$PATH && \ cd pychain/openfst_binding && python3 setup.py install && \ cd ../pytorch_binding && python3 setup.py install - -.PHONY: lhotse -lhotse: - test -d lhotse || git clone https://github.com/lhotse-speech/lhotse.git - export PATH=$(PYTHON_DIR):$$PATH && cd lhotse && pip install -e . - -.PHONY: k2 -k2: - test -d k2 || git clone https://github.com/k2-fsa/k2.git - cd k2 && mkdir -p build && cd build && $(CMAKE) .. 
&& $(MAKE) diff --git a/examples/mobvoihotwords/local/data_prep.py b/examples/mobvoihotwords/local/data_prep.py index df64dbd04..916983363 100755 --- a/examples/mobvoihotwords/local/data_prep.py +++ b/examples/mobvoihotwords/local/data_prep.py @@ -16,19 +16,29 @@ import numpy as np +import torch + from fairseq.data.data_utils import numpy_seed try: - # TODO use pip install once it's available - from espresso.tools.lhotse.lhotse import ( + from lhotse import ( CutSet, Mfcc, MfccConfig, LilcomFilesWriter, RecordingSet, SupervisionSet ) - from espresso.tools.lhotse.lhotse.augmentation import SoxEffectTransform, RandomValue - from espresso.tools.lhotse.lhotse.manipulation import combine - from espresso.tools.lhotse.lhotse.recipes.mobvoihotwords import download_and_untar, prepare_mobvoihotwords + from lhotse.augmentation import SoxEffectTransform, RandomValue + from lhotse.manipulation import combine + from lhotse.recipes.mobvoihotwords import download_and_untar, prepare_mobvoihotwords except ImportError: - raise ImportError("Please install Lhotse by `make lhotse` after entering espresso/tools") + raise ImportError("Please install Lhotse by `pip install lhotse`") + + +# Torch's multithreaded behavior needs to be disabled or it wastes a lot of CPU and +# slow things down. Do this outside of main() because it needs to take effect +# even when we are not invoking the main (notice: "spawn" is the method used +# in multiprocessing, which is to get around some problems with torchaudio's invocation of +# sox). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) logging.basicConfig( @@ -72,20 +82,18 @@ def main(args): download_and_untar(root_dir) logger.info(f"Prepare the manifests") - partitions = ["train", "dev", "test"] - if all( - (output_dir / f"{key}_{part}.json").is_file() - for key in ["recordings", "supervisions"] for part in partitions - ): - logger.info(f"All the manifests files are found in {output_dir}. 
Load from them directly") + data_parts = ["train", "dev", "test"] + try: mobvoihotwords_manifests = defaultdict(dict) - for part in partitions: - mobvoihotwords_manifests[part] = { - "recordings": RecordingSet.from_json(output_dir / f"recordings_{part}.json"), - "supervisions": SupervisionSet.from_json(output_dir / f"supervisions_{part}.json") - } - else: - logger.info("It may take long time") + for part in data_parts: + mobvoihotwords_manifests[part]["recordings"] = RecordingSet.from_json( + output_dir / f"recordings_{part}.json" + ) + mobvoihotwords_manifests[part]["supervisions"] = SupervisionSet.from_json( + output_dir / f"supervisions_{part}.json" + ) + except Exception as e: + logger.warning("Mobvoihotwords manifests not found on disk, preparing them from scratch: " + str(e)) mobvoihotwords_manifests = prepare_mobvoihotwords(corpus_dir, output_dir) logger.info( "train/dev/test size: {}/{}/{}".format( @@ -96,16 +104,23 @@ def main(args): ) # Data augmentation + num_workers = min(args.num_workers, os.cpu_count()) np.random.seed(args.seed) # equivalent to Kaldi's mfcc_hires config mfcc = Mfcc(config=MfccConfig(num_mel_bins=40, num_ceps=40, low_freq=20, high_freq=-400)) for part, manifests in mobvoihotwords_manifests.items(): + logger.info(part) + json_path = output_dir / f"cuts_{part}.json.gz" + if json_path.is_file(): + logger.info(f"{json_path} exists, skip the extraction (remove it if you want to re-generate it)") + continue + cut_set = CutSet.from_manifests( recordings=manifests["recordings"], supervisions=manifests["supervisions"], ) sampling_rate = next(iter(cut_set)).sampling_rate - with ProcessPoolExecutor(args.num_workers, mp_context=multiprocessing.get_context("spawn")) as ex: + with ProcessPoolExecutor(num_workers, mp_context=multiprocessing.get_context("spawn")) as ex: if part == "train": # split negative recordings into smaller chunks with lengths sampled from # length distribution of positive recordings @@ -195,26 +210,19 @@ def main(args): # combine the clean and augmented sets together logger.info(f"Combine all the features above") - cut_set = combine( - cut_set_clean, cut_set_rev, cut_set_sp1p1, cut_set_sp0p9 - ) + cut_set = combine(cut_set_clean, cut_set_rev, cut_set_sp1p1, cut_set_sp0p9) else: # no augmentations for dev and test sets logger.info(f"extract features for '{part}' set") - json_path = output_dir / f"cuts_{part}.json.gz" - if json_path.is_file(): - logger.info(f"{json_path} exists, skip the extraction (remove it if you want to re-generate it)") - cut_set = CutSet.from_json(json_path) - else: - with LilcomFilesWriter(f"{output_dir}/feats_{part}") as storage: - cut_set = cut_set.compute_and_store_features( - extractor=mfcc, - storage=storage, - augmenter=None, - executor=ex, - ) + with LilcomFilesWriter(f"{output_dir}/feats_{part}") as storage: + cut_set = cut_set.compute_and_store_features( + extractor=mfcc, + storage=storage, + augment_fn=None, + executor=ex, + ) - mobvoihotwords_manifests[part]["cuts"] = cut_set - cut_set.to_json(output_dir / f"cuts_{part}.json.gz") + mobvoihotwords_manifests[part]["cuts"] = cut_set + cut_set.to_json(output_dir / f"cuts_{part}.json.gz") def get_positive_durations(sup_set: SupervisionSet) -> List[float]: diff --git a/examples/mobvoihotwords/local/generate_graphs.py b/examples/mobvoihotwords/local/generate_graphs.py index 88f16954d..db981cc7d 100755 --- a/examples/mobvoihotwords/local/generate_graphs.py +++ b/examples/mobvoihotwords/local/generate_graphs.py @@ -47,25 +47,32 @@ def main(args): 
hmms.append(k2.Fsa.from_openfst(f.read(), acceptor=False)) hmms[-1] = k2.arc_sort(hmms[-1]) hmm_vec = k2.create_fsa_vec(hmms) - H = k2.union(hmm_vec) + H = k2.closure(k2.union(hmm_vec)) H_inv = k2.arc_sort(H.invert_()) with open(args.lexicon_fst_path, "r", encoding="utf-8") as f: L = k2.Fsa.from_openfst(f.read(), acceptor=False) - L = k2.arc_sort(L.invert_()).invert_() # sort on olabels + L = k2.arc_sort(L) - with open(args.phone_lm_fst_path, "r", encoding="utf-8") as f: + with open(args.phone_lm_fsa_path, "r", encoding="utf-8") as f: phone_lm = k2.Fsa.from_openfst(f.read(), acceptor=True) + assert not hasattr(phone_lm, "aux_labels") phone_lm = k2.arc_sort(phone_lm) # emulate composition - if hasattr(L, "aux_symbols"): - setattr(L, "temp_symbols", L.aux_symbols) - delattr(L, "aux_symbols") - HL = k2.intersect(H_inv, L) - if hasattr(L, "temp_symbols"): - setattr(L, "aux_symbols", L.temp_symbols) - delattr(L, "temp_symbols") + if hasattr(L, "aux_labels"): + L.temp_labels = L.aux_labels + del L.aux_labels + if hasattr(L, "aux_symbols"): + L.temp_symbols = L.aux_symbols + del L.aux_symbols + HL = k2.intersect(H_inv, L).invert_() + if hasattr(HL, "temp_labels"): + HL.aux_labels = HL.temp_labels + del HL.temp_labels + if hasattr(HL, "temp_symbols"): + HL.aux_symbols = HL.temp_symbols + del HL.temp_symbols HL = k2.arc_sort(HL) save_path = os.path.join(args.out_dir, "HL.pt") torch.save(HL.as_dict(), save_path) diff --git a/examples/mobvoihotwords/run.sh b/examples/mobvoihotwords/run.sh index b1ff75f61..72ebfd19f 100755 --- a/examples/mobvoihotwords/run.sh +++ b/examples/mobvoihotwords/run.sh @@ -22,7 +22,7 @@ wake_word1="NihaoWenwen" . ./cmd.sh . ./utils/parse_options.sh -dir=exp/tdnn_k2_${affix:+_$affix} +dir=exp/tdnn_k2${affix:+_$affix} if [ ${stage} -le 0 ]; then echo "Stage 0: Data Preparation" @@ -75,51 +75,51 @@ EOF id_freetext=`cat data/lang/phones.txt | grep "freetext" | awk '{print $2}'` cat > data/lang/hmm_sil.fst.txt < data/lang/hmm_freetext.fst.txt < data/lang/hmm_hixiaowen.fst.txt < data/lang/hmm_nihaowenwen.fst.txt < data/lang/phone_lm.fsa.txt 0 1 $id_sil -0 5 $id_sil +0 7 $id_sil 1 2 $id_word0 2 3 $id_sil 1 4 $id_word1 @@ -132,7 +132,8 @@ EOF EOF echo "Generate graphs for training" - local/generate_graphs.py --hmm-paths data/lang/hmm_{sil,freetext,hixiaowen,nihaowenwen}.fst.txt \ + log_file=data/log/generate_graphs.log + $train_cmd $log_file local/generate_graphs.py --hmm-paths data/lang/hmm_{sil,freetext,hixiaowen,nihaowenwen}.fst.txt \ --lexicon-fst-path data/lang/L.fst.txt --phone-lm-fsa-path data/lang/phone_lm.fsa.txt \ --out-dir data fi @@ -154,7 +155,7 @@ if [ ${stage} -le 2 ]; then [ -f $dir/checkpoint_last.pt ] && log_file="-a $log_file" update_freq=1 CUDA_VISIBLE_DEVICES=$free_gpu speech_train.py data --task speech_recognition_hybrid --seed 1 \ - --log-interval $((1500/ngpus/update_freq)) --log-format simple \ + --log-interval $((1500/ngpus/update_freq)) --log-format simple --use-k2-dataset \ --num-workers 0 --data-buffer-size 0 --max-tokens 25600 --batch-size 128 --empty-cache-freq 50 \ --valid-subset $valid_subset --batch-size-valid 128 --ddp-backend no_c10d --update-freq $update_freq \ --distributed-world-size $ngpus --arch speech_tdnn_mobvoi \ @@ -162,11 +163,11 @@ if [ ${stage} -le 2 ]; then --lr-scheduler reduce_lr_on_plateau_v2 --lr-shrink 0.5 \ --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates $((1500/ngpus/update_freq)) \ --keep-interval-updates 5 --keep-last-epochs 5 --validate-interval 1 \ - --criterion k2_lattice_free_mmi --num-targets 
$num_targets --word-symbol-path data/lang/words.txt \ + --criterion k2_lattice_free_mmi --num-targets $num_targets --word-symbol-table-path data/lang/words.txt \ --denominator-fst-path data/denominator.pt --HCL-fst-path data/HL.pt \ --max-source-positions 9999 --max-target-positions 9999 $opts 2>&1 | tee $log_file fi if [ ${stage} -le 3 ]; then - echo "Stage 3: Decoding" + echo "Stage 3: Dump Posteriors for Evaluation" fi diff --git a/setup.py b/setup.py index a587e15b9..42a0d145c 100644 --- a/setup.py +++ b/setup.py @@ -187,6 +187,8 @@ def do_setup(package_data): "hydra-core<1.1", "omegaconf<2.1", "kaldi_io", + "k2", + "lhotse", 'numpy<1.20.0; python_version<"3.7"', 'numpy; python_version>="3.7"', "regex", From 63ceaf2ae881e890094517a7fb71b1c4dadfd607 Mon Sep 17 00:00:00 2001 From: freewym Date: Fri, 27 Nov 2020 01:53:58 -0500 Subject: [PATCH 116/119] decoding related --- .../local/create_decoding_graph.py | 63 +++++++++++++ .../mobvoihotwords/local/decode_best_path.py | 74 +++++++++++++++ examples/mobvoihotwords/local/evaluate.py | 90 +++++++++++++++++++ examples/mobvoihotwords/run.sh | 74 +++++++++++++++ 4 files changed, 301 insertions(+) create mode 100755 examples/mobvoihotwords/local/create_decoding_graph.py create mode 100755 examples/mobvoihotwords/local/decode_best_path.py create mode 100755 examples/mobvoihotwords/local/evaluate.py diff --git a/examples/mobvoihotwords/local/create_decoding_graph.py b/examples/mobvoihotwords/local/create_decoding_graph.py new file mode 100755 index 000000000..1df55aff3 --- /dev/null +++ b/examples/mobvoihotwords/local/create_decoding_graph.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse +import logging +import os +import sys + +import torch + + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("mobvoihotwords.create_decoding_graph") + + +def get_parser(): + parser = argparse.ArgumentParser( + description="Create the decoding graph for decoding" + ) + # fmt: off + parser.add_argument("--HCL-fst-path", type=str, help="path to the HCL fst file (torch_saved)", required=True) + parser.add_argument("--lm-fsa-path", type=str, help="path to the LM fsa (openfst text format or torch saved)", required=True) + parser.add_argument("--out-dir", type=str, default="data", help="directory to save the decoding graph") + # fmt: on + + return parser + + +def main(args): + try: + import k2 + except ImportError: + raise ImportError("Please install k2 by `pip install k2`") + + HCL_inv = k2.Fsa.from_dict(torch.load(args.HCL_fst_path)).invert_() + HCL_inv = k2.arc_sort(HCL_inv) + + if args.lm_fsa_path[-3:] == ".pt": + G = k2.Fsa.from_dict(torch.load(args.lm_fsa_path)) + else: + with open(args.lm_fsa_path, "r", encoding="utf-8") as f: + G = k2.Fsa.from_openfst(f.read(), acceptor=True) + assert not hasattr(G, "aux_labels") + G = k2.arc_sort(G) + + decoding_graph = k2.intersect(HCL_inv, G).invert_() + save_path = os.path.join(args.out_dir, "HCLG.pt") + torch.save(decoding_graph.as_dict(), save_path) + logger.info(f"saved the decoding graph as {save_path}") + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + main(args) diff --git a/examples/mobvoihotwords/local/decode_best_path.py b/examples/mobvoihotwords/local/decode_best_path.py new file mode 100755 index 000000000..e3837704e --- /dev/null +++ b/examples/mobvoihotwords/local/decode_best_path.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
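create_decoding_graph.py above emulates FST composition in k2 by intersecting the arc-sorted grammar acceptor G with the inverted HCL transducer and inverting the result. A minimal standalone sketch of that idiom, assuming the HLinv.pt and lm/fsa.txt files produced earlier in run.sh (the paths below are illustrative only):

    import torch
    import k2

    # HCL_inv carries word labels with pdf-id aux_labels, so it can be intersected with G directly
    HCL_inv = k2.arc_sort(k2.Fsa.from_dict(torch.load("data/HLinv.pt")))
    with open("data/lang_test/lm/fsa.txt", "r", encoding="utf-8") as f:
        G = k2.arc_sort(k2.Fsa.from_openfst(f.read(), acceptor=True))
    # match on the shared word labels, then invert so the result reads pdf-ids and emits words
    HCLG = k2.invert(k2.connect(k2.intersect(G, HCL_inv)))
    torch.save(HCLG.as_dict(), "data/lang_test/graph/HCLG.pt")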
+ +import argparse +import logging +import os +import sys + +import torch + + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("mobvoihotwords.decode_best_path") + + +def get_parser(): + parser = argparse.ArgumentParser( + description="Decode by finding the best path" + ) + # fmt: off + parser.add_argument("--beam", type=float, default=10.0, help="decoding beam") + parser.add_argument("--word-symbol-table", type=str, help="path to the HCL fst file (torch_saved)", required=True) + parser.add_argument("decoding_graph", type=str, default="data", help="path to the decoding graph") + parser.add_argument("net_output", type=str, help="path to the network output file for acoustic scores") + parser.add_argument("hyp_file", type=str, help="path to the resulting hypotheses file") + # fmt: on + + return parser + + +def main(args): + try: + import k2 + except ImportError: + raise ImportError("Please install k2 by `pip install k2`") + try: + import kaldi_io + except ImportError: + raise ImportError("Please install kaldi_io by `pip install kaldi_io`") + + symbol_table = k2.SymbolTable.from_file(args.word_symbol_table) + graph = k2.Fsa.from_dict(torch.load(args.args.decoding_graph)) + graph.scores.requires_grad_(False) + + num_processed = 0 + with open(args.net_output, "r", encoding="utf-8") as f_in, open(args.hyp_file, "r", encoding="utf-8") as f_out: + for line in f_in: + utt_id, rxfile = line.strip().split(maxsplit=1) + net_output = torch.from_numpy(kaldi_io.read_mat(rxfile)).unsqueeze(0) # 1 x T x V + supervision_segments = net_output.new_tensor([0, 0, net_output.size(0)], dtype=torch.int).unsqueeze(0) # 1 x 3 + dense_fsa_vec = k2.DenseFsaVec(net_output, supervision_segments) + graph_unrolled = k2.intersect_dense_pruned( + graph, dense_fsa_vec, search_beam=args.beam, output_beam=15.0, min_active_states=0, max_active_states=10000 + ) + best_path = k2.shortest_path(graph_unrolled, use_float_scores=True) + hyp = [symbol_table._id2sym[x.item()] for x in best_path[0].aux_labels if x > 0] + print(utt_id, hyp, file=f_out) + num_processed += 1 + + logger.info(f"Processed {num_processed} utterances") + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + main(args) diff --git a/examples/mobvoihotwords/local/evaluate.py b/examples/mobvoihotwords/local/evaluate.py new file mode 100755 index 000000000..3717c620f --- /dev/null +++ b/examples/mobvoihotwords/local/evaluate.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +# Copyright (c) Yiming Wang +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
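decode_best_path.py above wraps each utterance's dumped posteriors in a k2.DenseFsaVec before intersecting with the decoding graph, and the LF-MMI criterion builds the same structure for a whole batch. A minimal sketch of the supervision layout it relies on, with made-up sizes standing in for real posteriors:

    import torch
    import k2

    T, V = 100, 8                                    # frames and network output dimension (illustrative)
    log_probs = torch.randn(1, T, V)                 # N x T x V log-probabilities
    # one (fsa_index, start_frame, num_frames) row per sequence, int32 on CPU;
    # batched rows are expected in descending order of num_frames
    supervision_segments = torch.tensor([[0, 0, T]], dtype=torch.int32)
    dense_fsa_vec = k2.DenseFsaVec(log_probs, supervision_segments)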
+ +import argparse +import logging +import os +import sys + +import torch + + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("mobvoihotwords.evaluate") + + +def get_parser(): + parser = argparse.ArgumentParser( + description="Evaluate by calculating detection metrics" + ) + # fmt: off + parser.add_argument("--wake-word", type=str, help="wake word to be treated as positive", required=True) + parser.add_argument("supervsion_file", type=str, help="path to the supervision set file") + parser.add_argument("hyp_file", type=str, help="path to the resulting hypotheses file") + parser.add_argument("result_file", type=str, help="path to the result file") + # fmt: on + + return parser + + +def main(args): + try: + from lhotse import SupervisionSet + except ImportError: + raise ImportError("Please install Lhotse by `pip install lhotse`") + + supervisions = SupervisionSet.from_json(args.recording_file) # one and only one supervision segment per recording + neg_dur = sum(sup.duration for sup in supervisions if sup.text != args.wake_word) + ref = [(sup.recording_id, sup.text) for sup in supervisions] + + hyp = {} + with open(args.hyp_file, "r", encoding="utf-8") as f: + for line in f: + split_line = line.strip().split(maxsplit=1) + hyp[split_line[0]] = split_line[1] if len(split_line) == 2 else "" + + if len(ref) != len(hyp): + logger.warning("The lengths of reference and hypothesis do not match. ref: {} vs hyp: {}.".format(len(ref), len(hyp))) + + TP = TN = FP = FN = 0.0 + for i in range(len(ref)): + if ref[i][0] not in hyp: + logger.warning("reference {} does not exist in hypothesis.".format(ref[i][0])) + continue + if ref[i][1] == args.wake_word: + if args.wake_word in hyp[ref[i][0]]: + TP += 1.0 + else: + FN += 1.0 + else: + if args.wake_word in hyp[ref[i][0]]: + FP += 1.0 + else: + TN += 1.0 + precision = TP / (TP + FP) if TP + FP > 0 else 0.0 + recall = TP / (TP + FN) if TP + FN > 0 else 0.0 + false_positive_rate = FP / (FP + TN) if FP + TN > 0 else 0.0 + false_negative_rate = FN / (FN + TP) if FN + TP > 0 else 0.0 + false_alarms_per_hour = FP / (neg_dur / 3600) if neg_dur > 0.0 else 0.0 + + with open(args.result_file, "w", encoding="utf-8") as f: + print( + "precision: {:.5f} recall: {:.5f} FPR: {:.5f} FNR: {:.5f} FP per hour: {:.5f} total: {:d}".format( + precision, recall, false_positive_rate, false_negative_rate, false_alarms_per_hour, TP + TN + FP + FN + ), + file=f + ) + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + main(args) diff --git a/examples/mobvoihotwords/run.sh b/examples/mobvoihotwords/run.sh index 72ebfd19f..4c5bf7773 100755 --- a/examples/mobvoihotwords/run.sh +++ b/examples/mobvoihotwords/run.sh @@ -170,4 +170,78 @@ fi if [ ${stage} -le 3 ]; then echo "Stage 3: Dump Posteriors for Evaluation" + path=$dir/$checkpoint + for dataset in $test_set; do + mkdir -p $dir/decode_$dataset/log + log_file=$dir/decode_$dataset/log/dump_posteriors.log + $cuda_cmd $log_file dump_posteriors.py data --use-k2-dataset \ + --task speech_recognition_hybrid --max-tokens 25600 --max-sentences 128 \ + --num-shards 1 --shard-id 0 --num-targets $num_targets --gen-subset $dataset \ + --max-source-positions 9999 --path $path \ + \| copy-matrix ark:- ark,scp:$dir/decode_$dataset/posteriors.ark,$dir/decode_$dataset/posteriors.scp || exit 1; + echo "log saved in $log_file" + done +fi + +if [ ${stage} 
-le 4 ]; then + echo "Stage 7: Decoding" + lang_test=data/lang_test + rm -rf $lang_test + cp -r data/lang $lang_test + utils/lang/make_lexicon_fst.py --sil-prob 0.0 --sil-phone SIL $lang_test/lexiconp.txt > $lang_test/L.fst.sym + utils/sym2int.pl -f 3 $lang_test/phones.txt <$lang_test/L.fst.sym - | \ + utils/sym2int.pl -f 4 $lang_test/words.txt - > $lang_test/L.fst.txt + + for wake_word in $wake_word0 $wake_word1; do + if [[ "$wake_word" == "$wake_word0" ]]; then + wake_word0_cost_range="-1.5 -1.0 -0.5 0.0 0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0" + wake_word1_cost_range="0.0" + else + wake_word0_cost_range="0.0" + wake_word1_cost_range="-1.5 -1.0 -0.5 0.0 0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0" + fi + for wake_word0_cost in $wake_word0_cost_range; do + for wake_word1_cost in $wake_word1_cost_range; do + sil_id=`cat $lang_test/words.txt | grep "" | awk '{print $2}'` + freetext_id=`cat $lang_test/words.txt | grep "FREETEXT" | awk '{print $2}'` + id0=`cat $lang_test/words.txt | grep $wake_word0 | awk '{print $2}'` + id1=`cat $lang_test/words.txt | grep $wake_word1 | awk '{print $2}'` + mkdir -p $lang_test/lm + cat < $lang_test/lm/fsa.txt +0 1 $sil_id +0 4 $sil_id 7.0 +1 4 $freetext_id 0.0 +4 0 $sil_id 0.0 +1 2 $id0 $wake_word0_cost +1 3 $id1 $wake_word1_cost +2 0 $sil_id +3 0 $sil_id +0 +EOF + local/create_decoding_graph.py --HCL-fst-path data/HL.pt --lm-fsa-path $lang_test/lm/fsa.txt $lang_test/graph || exit 1; + + rm $dir/.error 2>/dev/null || true + for dataset in $test_set; do + ( + nj=30 + score_dir=$dir/decode_$dataset/score_${wake_word}_${wake_word0_cost}_${wake_word1_cost} + mkdir -p $score_dir + $decode_cmd $dir/decode_$dataset/log/decode_${wake_word}.log \ + local/decode_best_path.py --beam=10 --word-symbol-table $lang_test/words.txt \ + $lang_test/graph/HCLG.pt $dir/decode_$dataset/posteriors.scp $score_dir/hyp.txt + local/evaluate.py --wake-word $wake_word \ + data/supervisions_${dataset}.json $score_dir/hyp.txt $score_dir/metrics + ) || touch $dir/.error & + done + wait + [ -f $dir/.error ] && echo "$0: there was a problem while decoding" && exit 1 + done + done + done + for dataset in $test_set; do + for wake_word in $wake_word0 $wake_word1; do + echo "Results on $dataset set with wake word ${wake_word}:" + cat $dir/decode_$dataset/score_${wake_word}_*/metrics + done + done fi From 66d84af940ce1afee9dd775429150d21b053efb0 Mon Sep 17 00:00:00 2001 From: freewym Date: Sat, 26 Dec 2020 02:10:02 -0500 Subject: [PATCH 117/119] some changes --- espresso/criterions/k2_lf_mmi_loss.py | 75 ++++++----- examples/mobvoihotwords/conf/gpu.conf | 2 +- .../local/create_decoding_graph.py | 15 +-- examples/mobvoihotwords/local/data_prep.py | 119 +++++------------- .../mobvoihotwords/local/decode_best_path.py | 12 +- .../mobvoihotwords/local/generate_graphs.py | 64 +++++----- examples/mobvoihotwords/run.sh | 23 ++-- 7 files changed, 135 insertions(+), 175 deletions(-) diff --git a/espresso/criterions/k2_lf_mmi_loss.py b/espresso/criterions/k2_lf_mmi_loss.py index 10cc153b9..c8b213e63 100644 --- a/espresso/criterions/k2_lf_mmi_loss.py +++ b/espresso/criterions/k2_lf_mmi_loss.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. 
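The criterion changes below rework how the numerator graphs are compiled, but the objective itself is unchanged: each utterance contributes the total log-score of its numerator lattice minus that of the denominator lattice, and the reduced loss is their negated sum, matching compute_loss() earlier in this series. A minimal sketch of that score computation, assuming num_graphs and den_graph are FsaVecs already on the same device as dense_fsa_vec:

    import k2

    num_lats = k2.intersect_dense_pruned(
        num_graphs, dense_fsa_vec, search_beam=100000, output_beam=100000,
        min_active_states=0, max_active_states=10000,
    )
    den_lats = k2.intersect_dense_pruned(
        den_graph, dense_fsa_vec, search_beam=100000, output_beam=100000,
        min_active_states=0, max_active_states=10000,
    )
    num_scores = k2.get_tot_scores(num_lats, log_semiring=True, use_double_scores=False)  # one score per utterance
    den_scores = k2.get_tot_scores(den_lats, log_semiring=True, use_double_scores=False)
    loss = -(num_scores - den_scores).sum()          # negative MMI objective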
from dataclasses import dataclass, field +from functools import lru_cache import logging import math from omegaconf import II @@ -31,11 +32,11 @@ @dataclass class K2LatticeFreeMMICriterionConfig(FairseqDataclass): sentence_avg: bool = II("optimization.sentence_avg") - denominator_fst_path: str = field( - default="???", metadata={"help": "path to the denominator fst file (torch saved)"} + denominator_graph_path: str = field( + default="???", metadata={"help": "path to the denominator graph file (torch saved)"} ) - HCL_fst_path: str = field( - default="???", metadata={"help": "path to the HCL fst file (torch saved)"} + HCL_inv_path: str = field( + default="???", metadata={"help": "path to the HCL_inv fst file (torch saved)"} ) word_symbol_table_path: str = field( default="???", metadata={"help": "path to the word symbol table file"} @@ -46,22 +47,34 @@ class K2LatticeFreeMMICriterionConfig(FairseqDataclass): ) -def create_numerator_graphs(texts: List[str], HCL_fst_inv: k2.Fsa, symbols: k2.SymbolTable, den_graph=None): - word_ids_list = [] - for text in texts: - filtered_text = [ - word if word in symbols._sym2id else "" for word in text.split(" ") - ] +def compile_numerator_graphs( + texts: List[str], symbols: k2.SymbolTable, HCL_inv: k2.Fsa, + unk_str: Optional[str] = "UNK", den_graph: Optional[k2.Fsa] = None +): + assert len(den_graph.shape) == 2 + + @lru_cache(maxsize=100000) + def compile_one_and_cache(text: str) -> k2.Fsa: + filtered_text = [token if token in symbols._sym2id else unk_str for token in text.split(" ")] word_ids = [symbols.get(word) for word in filtered_text] - word_ids_list.append(word_ids) - - fsa = k2.linear_fsa(word_ids_list) # create an FsaVec from a list of list of word ids - num_graphs = k2.intersect(fsa, HCL_fst_inv).invert_() - # TODO: normalize numerator - if False: #den_graph is not None: - num_graphs = k2.arc_sort(num_graphs) - num_graphs.scores = num_graphs.scores.new_zeros(num_graphs.scores.size()) # zero the score before intersect to avoid double counting - num_graphs = k2.intersect(num_graphs, den_graph) + fsa = k2.linear_fsa(word_ids) + #if H_inv is not None and L_inv is not None: + # LG = k2.connect(k2.intersect(fsa, L_inv)).invert_() + # del LG.aux_labels + # num_graph = k2.arc_sort(k2.invert(k2.connect(k2.intersect(H_inv, LG)))) + #else: + # assert HCL_inv is not None + num_graph = k2.invert(k2.connect(k2.intersect(fsa, HCL_inv))) + if den_graph is not None: + num_graph = k2.arc_sort(num_graph) + # zero the score before intersect to avoid double counting + num_graph.scores = num_graph.scores.new_zeros(num_graph.scores.size()) + # treat epsilon as normal labels, i.e., blanks + num_graph = k2.connect(k2.intersect(num_graph, den_graph, treat_epsilons_specially=False)) + del num_graph.aux_labels + return num_graph + + num_graphs = k2.create_fsa_vec([compile_one_and_cache(text) for text in texts]) return num_graphs @@ -70,17 +83,19 @@ class K2LatticeFreeMMICriterion(FairseqCriterion): def __init__(self, cfg: K2LatticeFreeMMICriterionConfig, task: FairseqTask): super().__init__(task) + self.unk_str = task.target_dictionary.unk_string() if task.target_dictionary is not None else "UNK" self.sentence_avg = cfg.sentence_avg self.den_graph = k2.create_fsa_vec( - [k2.Fsa.from_dict(torch.load(cfg.denominator_fst_path))] + [k2.Fsa.from_dict(torch.load(cfg.denominator_graph_path))] ) # has to be an FsaVec to be able to intersect with a batch of dense fsas if hasattr(self.den_graph, "aux_labels"): del self.den_graph.aux_labels - if hasattr(self.den_graph, 
"aux_symbols"): - del self.den_graph.aux_symbols self.den_graph.scores.requires_grad_(False) - self.den_graph_cpu = self.den_graph.clone() - self.HCL_fst_inv = k2.arc_sort(k2.Fsa.from_dict(torch.load(cfg.HCL_fst_path)).invert_()) + self.den_graph_cpu = self.den_graph[0].clone() # to be intersect with a individual numerator fsa + self.HCL_inv = k2.arc_sort(k2.Fsa.from_dict(torch.load(cfg.HCL_inv_path))) + #self.H_inv = k2.arc_sort(k2.Fsa.from_dict(torch.load(cfg.H_inv_path))) + #with open(cfg.L_path, "r", encoding="utf-8") as f: + # self.L_inv = k2.arc_sort(k2.Fsa.from_openfst(f.read(), acceptor=False).invert_()) self.symbol_table = k2.SymbolTable.from_file(cfg.word_symbol_table_path) self.xent_regularize = cfg.xent_regularization_coefficient self.subsampling_factor = None @@ -97,7 +112,7 @@ def forward( """ if self.subsampling_factor is None: assert hasattr(model, "output_lengths"), "model should implement the method `output_lengths()`" - self.subsampling_factor = int(round(100.0 / model.output_lengths(100))) + self.subsampling_factor = int(round(120.0 / model.output_lengths(120))) net_output = model(**sample["net_input"]) loss, nll_loss = self.compute_loss(net_output, sample, reduce=reduce) @@ -131,25 +146,25 @@ def compute_loss( out_lengths ), dim=1 - ).int().cpu() + ).int().cpu() # assume batched in descending order of lengths dense_fsa_vec = k2.DenseFsaVec(encoder_out, supervision_segments) # numerator computation - num_graphs = create_numerator_graphs( - sample["target"]["text"], self.HCL_fst_inv, self.symbol_table, den_graph=self.den_graph_cpu + num_graphs = compile_numerator_graphs( + sample["target"]["text"], self.symbol_table, HCL_inv=self.HCL_inv, unk_str=self.unk_str, den_graph=self.den_graph_cpu ).to(encoder_out.device) num_graphs.scores.requires_grad_(False) num_graphs_unrolled = k2.intersect_dense_pruned( num_graphs, dense_fsa_vec, search_beam=100000, output_beam=100000, min_active_states=0, max_active_states=10000 ) - num_scores = k2.get_tot_scores(num_graphs_unrolled, log_semiring=True, use_float_scores=True) + num_scores = k2.get_tot_scores(num_graphs_unrolled, log_semiring=True, use_double_scores=False) # denominator computation self.den_graph = self.den_graph.to(encoder_out.device) den_graph_unrolled = k2.intersect_dense_pruned( self.den_graph, dense_fsa_vec, search_beam=100000, output_beam=100000, min_active_states=0, max_active_states=10000 ) - den_scores = k2.get_tot_scores(den_graph_unrolled, log_semiring=True, use_float_scores=True) + den_scores = k2.get_tot_scores(den_graph_unrolled, log_semiring=True, use_double_scores=False) # obtain the loss if reduce: diff --git a/examples/mobvoihotwords/conf/gpu.conf b/examples/mobvoihotwords/conf/gpu.conf index 5cc94adf2..f5f7e0ff1 100644 --- a/examples/mobvoihotwords/conf/gpu.conf +++ b/examples/mobvoihotwords/conf/gpu.conf @@ -7,4 +7,4 @@ option num_threads=1 # Do not add anything to qsub_opts option max_jobs_run=* -tc $0 default gpu=0 option gpu=0 -option gpu=* -l 'hostname=c*,gpu=$0' -q g.q +option gpu=* -l 'hostname=c*&!c2*,gpu=$0' -q g.q diff --git a/examples/mobvoihotwords/local/create_decoding_graph.py b/examples/mobvoihotwords/local/create_decoding_graph.py index 1df55aff3..23e00d86d 100755 --- a/examples/mobvoihotwords/local/create_decoding_graph.py +++ b/examples/mobvoihotwords/local/create_decoding_graph.py @@ -26,8 +26,8 @@ def get_parser(): description="Create the decoding graph for decoding" ) # fmt: off - parser.add_argument("--HCL-fst-path", type=str, help="path to the HCL fst file (torch_saved)", 
required=True) - parser.add_argument("--lm-fsa-path", type=str, help="path to the LM fsa (openfst text format or torch saved)", required=True) + parser.add_argument("--HCL-inv-path", type=str, help="path to the HCL_inv fst file (torch_saved)", required=True) + parser.add_argument("--G-path", type=str, help="path to the LM fsa (openfst text format or torch saved)", required=True) parser.add_argument("--out-dir", type=str, default="data", help="directory to save the decoding graph") # fmt: on @@ -40,20 +40,21 @@ def main(args): except ImportError: raise ImportError("Please install k2 by `pip install k2`") - HCL_inv = k2.Fsa.from_dict(torch.load(args.HCL_fst_path)).invert_() - HCL_inv = k2.arc_sort(HCL_inv) + HCL_inv = k2.arc_sort(k2.Fsa.from_dict(torch.load(args.HCL_inv_path))) if args.lm_fsa_path[-3:] == ".pt": - G = k2.Fsa.from_dict(torch.load(args.lm_fsa_path)) + G = k2.Fsa.from_dict(torch.load(args.G_path)) else: with open(args.lm_fsa_path, "r", encoding="utf-8") as f: G = k2.Fsa.from_openfst(f.read(), acceptor=True) assert not hasattr(G, "aux_labels") G = k2.arc_sort(G) - decoding_graph = k2.intersect(HCL_inv, G).invert_() + HCLG = k2.invert(k2.connect(k2.intersect(G, HCL_inv))) + HCLG = k2.determinize(HCLG) + HCLG = k2.connect(HCLG) save_path = os.path.join(args.out_dir, "HCLG.pt") - torch.save(decoding_graph.as_dict(), save_path) + torch.save(HCLG.as_dict(), save_path) logger.info(f"saved the decoding graph as {save_path}") diff --git a/examples/mobvoihotwords/local/data_prep.py b/examples/mobvoihotwords/local/data_prep.py index 916983363..52f0adfb9 100755 --- a/examples/mobvoihotwords/local/data_prep.py +++ b/examples/mobvoihotwords/local/data_prep.py @@ -23,11 +23,10 @@ try: from lhotse import ( - CutSet, Mfcc, MfccConfig, LilcomFilesWriter, RecordingSet, SupervisionSet + CutSet, Mfcc, MfccConfig, LilcomFilesWriter, RecordingSet, SupervisionSet, combine ) from lhotse.augmentation import SoxEffectTransform, RandomValue - from lhotse.manipulation import combine - from lhotse.recipes.mobvoihotwords import download_and_untar, prepare_mobvoihotwords + from lhotse.recipes.mobvoihotwords import download_mobvoihotwords, prepare_mobvoihotwords except ImportError: raise ImportError("Please install Lhotse by `pip install lhotse`") @@ -79,22 +78,10 @@ def main(args): output_dir = root_dir logger.info(f"Download and extract the corpus") - download_and_untar(root_dir) + download_mobvoihotwords(root_dir) logger.info(f"Prepare the manifests") - data_parts = ["train", "dev", "test"] - try: - mobvoihotwords_manifests = defaultdict(dict) - for part in data_parts: - mobvoihotwords_manifests[part]["recordings"] = RecordingSet.from_json( - output_dir / f"recordings_{part}.json" - ) - mobvoihotwords_manifests[part]["supervisions"] = SupervisionSet.from_json( - output_dir / f"supervisions_{part}.json" - ) - except Exception as e: - logger.warning("Mobvoihotwords manifests not found on disk, preparing them from scratch: " + str(e)) - mobvoihotwords_manifests = prepare_mobvoihotwords(corpus_dir, output_dir) + mobvoihotwords_manifests = prepare_mobvoihotwords(corpus_dir, output_dir) logger.info( "train/dev/test size: {}/{}/{}".format( len(mobvoihotwords_manifests["train"]["recordings"]), @@ -134,21 +121,20 @@ def main(args): overlap_duration=args.overlap_duration, ) - # "clean" set - logger.info(f"Extract features for '{part}' set") - json_path = output_dir / f"cuts_{part}_clean.json.gz" + # "clean + speed-perturbed" set + logger.info(f"Extract features for '{part}' set with speed perturbation") + json_path = 
output_dir / f"cuts_{part}_sp.json.gz" if json_path.is_file(): logger.info(f"{json_path} exists, skip the extraction (remove it if you want to re-generate it)") - cut_set_clean = CutSet.from_json(json_path) + cut_set_sp = CutSet.from_json(json_path) else: - with LilcomFilesWriter(f"{output_dir}/feats_{part}_clean") as storage: - cut_set_clean = cut_set.compute_and_store_features( - extractor=mfcc, - storage=storage, - augment_fn=None, - executor=ex, - ) - cut_set_clean.to_json(json_path) + cut_set_sp = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + cut_set_sp = cut_set_sp.compute_and_store_features( + extractor=mfcc, + storage=LilcomFilesWriter(f"{output_dir}/feats_{part}_sp"), + executor=ex, + ) + cut_set_sp.to_json(json_path) # augmented with reverberation logger.info(f"Extract features for '{part}' set with reverberation") @@ -158,68 +144,27 @@ def main(args): cut_set_rev = CutSet.from_json(json_path) else: augment_fn = SoxEffectTransform(effects=reverb(sampling_rate=sampling_rate)) - with LilcomFilesWriter(f"{output_dir}/feats_{part}_rev") as storage: - with numpy_seed(args.seed): - cut_set_rev = cut_set.compute_and_store_features( - extractor=mfcc, - storage=storage, - augment_fn=augment_fn, - executor=ex, - ) - cut_set_rev = CutSet.from_cuts( - cut.with_id("rev-" + cut.id) for cut in cut_set_rev - ) - cut_set_rev.to_json(json_path) - - # augmented with speed perturbation - logger.info(f"Extract features for '{part}' set with speed perturbation") - json_path = output_dir / f"cuts_{part}_sp1.1.json.gz" - if json_path.is_file(): - logger.info(f"{json_path} exists, skip the extraction (remove it if you want to re-generate it)") - cut_set_sp1p1 = CutSet.from_json(json_path) - else: - augment_fn = SoxEffectTransform(effects=speed(sampling_rate=sampling_rate, factor=1.1)) - with LilcomFilesWriter(f"{output_dir}/feats_{part}_sp1.1") as storage: - cut_set_sp1p1 = cut_set.compute_and_store_features( + with numpy_seed(args.seed): + cut_set_rev = cut_set.modify_ids( + lambda cut_id: f"rev-{cut_id}" + ).compute_and_store_features( extractor=mfcc, - storage=storage, + storage=LilcomFilesWriter(f"{output_dir}/feats_{part}_rev"), augment_fn=augment_fn, executor=ex, ) - cut_set_sp1p1 = CutSet.from_cuts( - cut.with_id("sp1.1-" + cut.id) for cut in cut_set_sp1p1 - ) - cut_set_sp1p1.to_json(json_path) - json_path = output_dir / f"cuts_{part}_sp0.9.json.gz" - if json_path.is_file(): - logger.info(f"{json_path} exists, skip the extraction") - cut_set_sp1p1 = CutSet.from_json(json_path) - else: - augment_fn = SoxEffectTransform(effects=speed(sampling_rate=sampling_rate, factor=0.9)) - with LilcomFilesWriter(f"{output_dir}/feats_{part}_sp0.9") as storage: - cut_set_sp0p9 = cut_set.compute_and_store_features( - extractor=mfcc, - storage=storage, - augment_fn=augment_fn, - executor=ex, - ) - cut_set_sp0p9 = CutSet.from_cuts( - cut.with_id("sp0.9-" + cut.id) for cut in cut_set_sp0p9 - ) - cut_set_sp0p9.to_json(json_path) + cut_set_rev.to_json(json_path) - # combine the clean and augmented sets together + # combine all the augmented sets together logger.info(f"Combine all the features above") - cut_set = combine(cut_set_clean, cut_set_rev, cut_set_sp1p1, cut_set_sp0p9) + cut_set = combine(cut_set_sp, cut_set_rev) else: # no augmentations for dev and test sets logger.info(f"extract features for '{part}' set") - with LilcomFilesWriter(f"{output_dir}/feats_{part}") as storage: - cut_set = cut_set.compute_and_store_features( - extractor=mfcc, - storage=storage, - augment_fn=None, - executor=ex, 
- ) + cut_set = cut_set.compute_and_store_features( + extractor=mfcc, + storage=LilcomFilesWriter(f"{output_dir}/feats_{part}"), + executor=ex, + ) mobvoihotwords_manifests[part]["cuts"] = cut_set cut_set.to_json(output_dir / f"cuts_{part}.json.gz") @@ -306,14 +251,6 @@ def reverb(sampling_rate: int) -> List[List[str]]: ] -def speed(sampling_rate: int, factor: float) -> List[List[str]]: - return [ - # speed perturbation with a factor - ["speed", factor], - ["rate", sampling_rate], # Resample back to the original sampling rate (speed changes it) - ] - - if __name__ == "__main__": parser = get_parser() args = parser.parse_args() diff --git a/examples/mobvoihotwords/local/decode_best_path.py b/examples/mobvoihotwords/local/decode_best_path.py index e3837704e..2fa6be205 100755 --- a/examples/mobvoihotwords/local/decode_best_path.py +++ b/examples/mobvoihotwords/local/decode_best_path.py @@ -57,11 +57,19 @@ def main(args): net_output = torch.from_numpy(kaldi_io.read_mat(rxfile)).unsqueeze(0) # 1 x T x V supervision_segments = net_output.new_tensor([0, 0, net_output.size(0)], dtype=torch.int).unsqueeze(0) # 1 x 3 dense_fsa_vec = k2.DenseFsaVec(net_output, supervision_segments) + graph = graph.to(dense_fsa_vec.device) graph_unrolled = k2.intersect_dense_pruned( graph, dense_fsa_vec, search_beam=args.beam, output_beam=15.0, min_active_states=0, max_active_states=10000 ) - best_path = k2.shortest_path(graph_unrolled, use_float_scores=True) - hyp = [symbol_table._id2sym[x.item()] for x in best_path[0].aux_labels if x > 0] + best_path = k2.shortest_path(graph_unrolled, use_double_scores=False) + if isinstance(best_path[0].aux_labels, torch.Tensor): + aux_labels = best_paths[0].aux_labels + else: + # it's a ragged tensor + aux_labels = best_path[0].aux_labels.values() + aux_labels = aux_labels[aux_labels > 0] + aux_labels = aux_labels.tolist() + hyp = [symbol_table.get(x) for x in aux_labels] print(utt_id, hyp, file=f_out) num_processed += 1 diff --git a/examples/mobvoihotwords/local/generate_graphs.py b/examples/mobvoihotwords/local/generate_graphs.py index db981cc7d..94aac217b 100755 --- a/examples/mobvoihotwords/local/generate_graphs.py +++ b/examples/mobvoihotwords/local/generate_graphs.py @@ -26,9 +26,9 @@ def get_parser(): description="Generate graphs for training" ) # fmt: off - parser.add_argument("--hmm-paths", nargs="+", help="list of HMM paths", required=True) - parser.add_argument("--lexicon-fst-path", type=str, help="path to the lexicon fst", required=True) - parser.add_argument("--phone-lm-fsa-path", type=str, help="path to the phone LM fsa", required=True) + parser.add_argument("--hmm-paths", nargs="+", help="list of HMM paths (in openfst text format)", required=True) + parser.add_argument("--L-path", type=str, help="path to L fst (in openfst text formet)", required=True) + parser.add_argument("--phone-lm-fsa-path", type=str, help="path to the phone LM fsa (in openfst text format)", required=True) parser.add_argument("--out-dir", type=str, default="data", help="directory to save output graphs") # fmt: on @@ -47,39 +47,39 @@ def main(args): hmms.append(k2.Fsa.from_openfst(f.read(), acceptor=False)) hmms[-1] = k2.arc_sort(hmms[-1]) hmm_vec = k2.create_fsa_vec(hmms) - H = k2.closure(k2.union(hmm_vec)) - H_inv = k2.arc_sort(H.invert_()) - - with open(args.lexicon_fst_path, "r", encoding="utf-8") as f: - L = k2.Fsa.from_openfst(f.read(), acceptor=False) - L = k2.arc_sort(L) + # temporarily +1 to make label 0 of HMMs (means "blank") be treated as normal label + hmm_vec.labels = 
torch.where(hmm_vec.labels >= 0, hmm_vec.labels + 1, hmm_vec.labels) + H = k2.connect(k2.remove_epsilons_iterative_tropical(k2.closure(k2.union(hmm_vec)))) + H.labels = torch.where(H.labels > 0, H.labels - 1, H.labels) # restore the label indices + H_inv = k2.arc_sort(k2.invert(H)) + #save_path = os.path.join(args.out_dir, "Hinv.pt") + #torch.save(H_inv.as_dict(), save_path) + #logger.info(f"saved H_inv as {save_path}") + + with open(args.L_path, "r", encoding="utf-8") as f: + L = k2.arc_sort(k2.Fsa.from_openfst(f.read(), acceptor=False)) with open(args.phone_lm_fsa_path, "r", encoding="utf-8") as f: - phone_lm = k2.Fsa.from_openfst(f.read(), acceptor=True) - assert not hasattr(phone_lm, "aux_labels") - phone_lm = k2.arc_sort(phone_lm) + phone_lm = k2.arc_sort(k2.Fsa.from_openfst(f.read(), acceptor=True)) # emulate composition - if hasattr(L, "aux_labels"): - L.temp_labels = L.aux_labels - del L.aux_labels - if hasattr(L, "aux_symbols"): - L.temp_symbols = L.aux_symbols - del L.aux_symbols - HL = k2.intersect(H_inv, L).invert_() - if hasattr(HL, "temp_labels"): - HL.aux_labels = HL.temp_labels - del HL.temp_labels - if hasattr(HL, "temp_symbols"): - HL.aux_symbols = HL.temp_symbols - del HL.temp_symbols - HL = k2.arc_sort(HL) - save_path = os.path.join(args.out_dir, "HL.pt") - torch.save(HL.as_dict(), save_path) - logger.info(f"saved the HL fst as {save_path}") - - den_graph = k2.intersect(H_inv, phone_lm).invert_() - den_graph = k2.arc_sort(den_graph) + #L_clone = L.clone() + #if hasattr(L, "aux_labels"): + # L.temp_labels = L.aux_labels + # del L.aux_labels + #HL = k2.invert(k2.connect(k2.intersect(L, H_inv))) + #if hasattr(HL, "temp_labels"): + # HL.aux_labels = HL.temp_labels + # del HL.temp_labels + #HL_inv = k2.arc_sort(k2.invert(HL)) + HL_inv = k2.arc_sort(k2.connect(k2.compose(L.invert(), H_inv))) + #print(k2.is_rand_equivalent(HL_inv,HL_inv_new,log_semiring=True)) + save_path = os.path.join(args.out_dir, "HLinv.pt") + torch.save(HL_inv.as_dict(), save_path) + logger.info(f"saved the HL_inv fst as {save_path}") + + den_graph = k2.arc_sort(k2.invert(k2.connect(k2.intersect(H_inv, phone_lm)))) + del den_graph.aux_labels save_path = os.path.join(args.out_dir, "denominator.pt") torch.save(den_graph.as_dict(), save_path) logger.info(f"saved the denominator graph as {save_path}") diff --git a/examples/mobvoihotwords/run.sh b/examples/mobvoihotwords/run.sh index 4c5bf7773..1fb6e5bca 100755 --- a/examples/mobvoihotwords/run.sh +++ b/examples/mobvoihotwords/run.sh @@ -37,9 +37,9 @@ if [ ${stage} -le 1 ]; then echo "Prepare the lexicon" mkdir -p data/lang cat > data/lang/lexiconp.txt < 1.0 SIL EOF @@ -50,9 +50,9 @@ EOF cat > data/lang/phones.txt < 0 SIL 1 -hixiaowen 2 -nihaowenwen 3 -freetext 4 +freetext 2 +hixiaowen 3 +nihaowenwen 4 EOF echo "Prepare words symbol table" @@ -72,7 +72,6 @@ EOF id_freetext=`cat data/lang/phones.txt | grep "freetext" | awk '{print $2}'` id_word0=`cat data/lang/phones.txt | grep "hixiaowen" | awk '{print $2}'` id_word1=`cat data/lang/phones.txt | grep "nihaowenwen" | awk '{print $2}'` - id_freetext=`cat data/lang/phones.txt | grep "freetext" | awk '{print $2}'` cat > data/lang/hmm_sil.fst.txt < data/lang/phone_lm.fsa.txt + cat > data/lang/phone_lm.fsa.txt <&1 | tee $log_file + --denominator-graph-path data/denominator.pt --HCL-inv-path data/HLinv.pt \ + --max-source-positions 9999 --max-target-positions 9999 $opts || exit 1; fi if [ ${stage} -le 3 ]; then @@ -218,7 +217,7 @@ if [ ${stage} -le 4 ]; then 3 0 $sil_id 0 EOF - local/create_decoding_graph.py 
--HCL-fst-path data/HL.pt --lm-fsa-path $lang_test/lm/fsa.txt $lang_test/graph || exit 1; + local/create_decoding_graph.py --HCL-inv-path data/HLinv.pt --G-path $lang_test/lm/fsa.txt $lang_test/graph || exit 1; rm $dir/.error 2>/dev/null || true for dataset in $test_set; do From 8a345cb28dccd5cefe5d5e45b83e710c0ba1efa0 Mon Sep 17 00:00:00 2001 From: freewym Date: Sun, 27 Dec 2020 04:30:39 -0500 Subject: [PATCH 118/119] fix negative loss --- espresso/criterions/k2_lf_mmi_loss.py | 47 +++++++++------ espresso/dump_posteriors.py | 4 +- ..._graphs.py => create_H_and_denominator.py} | 37 ++++-------- .../local/create_decoding_graph.py | 60 ++++++++++++++----- .../mobvoihotwords/local/decode_best_path.py | 12 ++-- examples/mobvoihotwords/local/evaluate.py | 4 +- examples/mobvoihotwords/run.sh | 49 +++++++++------ 7 files changed, 127 insertions(+), 86 deletions(-) rename examples/mobvoihotwords/local/{generate_graphs.py => create_H_and_denominator.py} (59%) diff --git a/espresso/criterions/k2_lf_mmi_loss.py b/espresso/criterions/k2_lf_mmi_loss.py index c8b213e63..768c7c841 100644 --- a/espresso/criterions/k2_lf_mmi_loss.py +++ b/espresso/criterions/k2_lf_mmi_loss.py @@ -35,12 +35,18 @@ class K2LatticeFreeMMICriterionConfig(FairseqDataclass): denominator_graph_path: str = field( default="???", metadata={"help": "path to the denominator graph file (torch saved)"} ) - HCL_inv_path: str = field( - default="???", metadata={"help": "path to the HCL_inv fst file (torch saved)"} + H_path: str = field( + default="???", metadata={"help": "path to the H fst file (torch saved). Note: pdf-ids are offset by +1"} + ) + L_path: str = field( + default="???", metadata={"help": "path to the L fst file (openfst text format or torch saved)"} ) word_symbol_table_path: str = field( default="???", metadata={"help": "path to the word symbol table file"} ) + phone_symbol_table_path: str = field( + default="???", metadata={"help": "path to the phone symbol table file"} + ) xent_regularization_coefficient: float = field( default=0.0, metadata={"help": "cross-entropy regularization coefficient"}, @@ -48,7 +54,7 @@ class K2LatticeFreeMMICriterionConfig(FairseqDataclass): def compile_numerator_graphs( - texts: List[str], symbols: k2.SymbolTable, HCL_inv: k2.Fsa, + texts: List[str], symbols: k2.SymbolTable, H_inv: k2.Fsa, L_inv: k2.Fsa, first_phone_disambig_id: int, unk_str: Optional[str] = "UNK", den_graph: Optional[k2.Fsa] = None ): assert len(den_graph.shape) == 2 @@ -58,20 +64,22 @@ def compile_one_and_cache(text: str) -> k2.Fsa: filtered_text = [token if token in symbols._sym2id else unk_str for token in text.split(" ")] word_ids = [symbols.get(word) for word in filtered_text] fsa = k2.linear_fsa(word_ids) - #if H_inv is not None and L_inv is not None: - # LG = k2.connect(k2.intersect(fsa, L_inv)).invert_() - # del LG.aux_labels - # num_graph = k2.arc_sort(k2.invert(k2.connect(k2.intersect(H_inv, LG)))) - #else: - # assert HCL_inv is not None - num_graph = k2.invert(k2.connect(k2.intersect(fsa, HCL_inv))) + LG = k2.connect(k2.intersect(fsa, L_inv)).invert_() + LG = k2.connect(k2.determinize(LG)) + LG.labels[LG.labels >= first_phone_disambig_id] = 0 + LG = k2.arc_sort(k2.connect(k2.remove_epsilons_iterative_tropical(LG))) + del LG.aux_labels + num_graph = k2.arc_sort(k2.invert(k2.connect(k2.intersect(H_inv, LG)))) + num_graph = k2.connect(k2.remove_epsilons_iterative_tropical(num_graph)) + num_graph = k2.connect(k2.determinize(num_graph)) + del num_graph.aux_labels + num_graph.labels = torch.where(num_graph.labels > 0, 
num_graph.labels - 1, num_graph.labels) if den_graph is not None: num_graph = k2.arc_sort(num_graph) # zero the score before intersect to avoid double counting num_graph.scores = num_graph.scores.new_zeros(num_graph.scores.size()) # treat epsilon as normal labels, i.e., blanks num_graph = k2.connect(k2.intersect(num_graph, den_graph, treat_epsilons_specially=False)) - del num_graph.aux_labels return num_graph num_graphs = k2.create_fsa_vec([compile_one_and_cache(text) for text in texts]) @@ -92,11 +100,15 @@ def __init__(self, cfg: K2LatticeFreeMMICriterionConfig, task: FairseqTask): del self.den_graph.aux_labels self.den_graph.scores.requires_grad_(False) self.den_graph_cpu = self.den_graph[0].clone() # to be intersect with a individual numerator fsa - self.HCL_inv = k2.arc_sort(k2.Fsa.from_dict(torch.load(cfg.HCL_inv_path))) - #self.H_inv = k2.arc_sort(k2.Fsa.from_dict(torch.load(cfg.H_inv_path))) - #with open(cfg.L_path, "r", encoding="utf-8") as f: - # self.L_inv = k2.arc_sort(k2.Fsa.from_openfst(f.read(), acceptor=False).invert_()) + self.H_inv = k2.arc_sort(k2.invert(k2.Fsa.from_dict(torch.load(cfg.H_path)))) + if cfg.L_path[-3:] == ".pt": + self.L_inv = k2.arc_sort(k2.Fsa.from_dict(torch.load(args.L_path)).invert_()) + else: + with open(cfg.L_path, "r", encoding="utf-8") as f: + self.L_inv = k2.arc_sort(k2.Fsa.from_openfst(f.read(), acceptor=False).invert_()) self.symbol_table = k2.SymbolTable.from_file(cfg.word_symbol_table_path) + phone_symbol_table = k2.SymbolTable.from_file(cfg.phone_symbol_table_path) + self.first_phone_disambig_id = min(v for k, v in phone_symbol_table._sym2id.items() if k.startswith("#")) self.xent_regularize = cfg.xent_regularization_coefficient self.subsampling_factor = None @@ -134,8 +146,6 @@ def compute_loss( ): # create the dense fsts from the network's output encoder_out = net_output["encoder_out"][0].transpose(0, 1) # T x B x V -> B x T x V - if torch.isnan(encoder_out).int().sum().item() > 0 or torch.isinf(encoder_out).int().sum().item() > 0: - print("nan",torch.isnan(encoder_out).int().sum().item(), "inf", torch.isinf(encoder_out).int().sum().item()) encoder_out = encoder_out.clamp(-30, 30) # clamp to avoid numerical overflows out_lengths = net_output["src_lengths"][0] # B supervision_segments = torch.stack( @@ -151,7 +161,8 @@ def compute_loss( # numerator computation num_graphs = compile_numerator_graphs( - sample["target"]["text"], self.symbol_table, HCL_inv=self.HCL_inv, unk_str=self.unk_str, den_graph=self.den_graph_cpu + sample["target"]["text"], self.symbol_table, H_inv=self.H_inv, L_inv=self.L_inv, first_phone_disambig_id=self.first_phone_disambig_id, + unk_str=self.unk_str, den_graph=self.den_graph_cpu ).to(encoder_out.device) num_graphs.scores.requires_grad_(False) num_graphs_unrolled = k2.intersect_dense_pruned( diff --git a/espresso/dump_posteriors.py b/espresso/dump_posteriors.py index 62c02cd37..0119259c8 100755 --- a/espresso/dump_posteriors.py +++ b/espresso/dump_posteriors.py @@ -165,10 +165,10 @@ def _main(cfg, output_file): if out_lengths is not None: for i in range(sample["nsentences"]): length = out_lengths[i] - kaldi_io.write_mat(f, lprobs[i, : length, :].cpu().numpy(), key=sample["utt_id"][i]) + kaldi_io.write_mat(f, lprobs[i, : length, :].cpu().numpy(), key=sample["reco_id"][i] if "reco_id" in sample else sample["utt_id"][i]) else: for i in range(sample["nsentences"]): - kaldi_io.write_mat(f, lprobs[i, :, :].cpu().numpy(), key=sample["utt_id"][i]) + kaldi_io.write_mat(f, lprobs[i, :, :].cpu().numpy(), key=sample["reco_id"][i] 
if "reco_id" in sample else sample["utt_id"][i]) else: # dumping chunks within the same utterance from left to right for sample in progress: # sample is actually a list of batches sample = utils.move_to_cuda(sample) if use_cuda else sample diff --git a/examples/mobvoihotwords/local/generate_graphs.py b/examples/mobvoihotwords/local/create_H_and_denominator.py similarity index 59% rename from examples/mobvoihotwords/local/generate_graphs.py rename to examples/mobvoihotwords/local/create_H_and_denominator.py index 94aac217b..f5a10685d 100755 --- a/examples/mobvoihotwords/local/generate_graphs.py +++ b/examples/mobvoihotwords/local/create_H_and_denominator.py @@ -18,7 +18,7 @@ level=os.environ.get("LOGLEVEL", "INFO").upper(), stream=sys.stdout, ) -logger = logging.getLogger("mobvoihotwords.generate_graphs") +logger = logging.getLogger("mobvoihotwords.create_H_and_denominator") def get_parser(): @@ -27,7 +27,6 @@ def get_parser(): ) # fmt: off parser.add_argument("--hmm-paths", nargs="+", help="list of HMM paths (in openfst text format)", required=True) - parser.add_argument("--L-path", type=str, help="path to L fst (in openfst text formet)", required=True) parser.add_argument("--phone-lm-fsa-path", type=str, help="path to the phone LM fsa (in openfst text format)", required=True) parser.add_argument("--out-dir", type=str, default="data", help="directory to save output graphs") # fmt: on @@ -49,36 +48,20 @@ def main(args): hmm_vec = k2.create_fsa_vec(hmms) # temporarily +1 to make label 0 of HMMs (means "blank") be treated as normal label hmm_vec.labels = torch.where(hmm_vec.labels >= 0, hmm_vec.labels + 1, hmm_vec.labels) - H = k2.connect(k2.remove_epsilons_iterative_tropical(k2.closure(k2.union(hmm_vec)))) - H.labels = torch.where(H.labels > 0, H.labels - 1, H.labels) # restore the label indices + H = k2.closure(k2.union(hmm_vec)) H_inv = k2.arc_sort(k2.invert(H)) - #save_path = os.path.join(args.out_dir, "Hinv.pt") - #torch.save(H_inv.as_dict(), save_path) - #logger.info(f"saved H_inv as {save_path}") - - with open(args.L_path, "r", encoding="utf-8") as f: - L = k2.arc_sort(k2.Fsa.from_openfst(f.read(), acceptor=False)) + save_path = os.path.join(args.out_dir, "H.pt") # H's pdf-ids are offset by +1 + torch.save(H.as_dict(), save_path) + logger.info(f"saved H as {save_path}") with open(args.phone_lm_fsa_path, "r", encoding="utf-8") as f: phone_lm = k2.arc_sort(k2.Fsa.from_openfst(f.read(), acceptor=True)) - # emulate composition - #L_clone = L.clone() - #if hasattr(L, "aux_labels"): - # L.temp_labels = L.aux_labels - # del L.aux_labels - #HL = k2.invert(k2.connect(k2.intersect(L, H_inv))) - #if hasattr(HL, "temp_labels"): - # HL.aux_labels = HL.temp_labels - # del HL.temp_labels - #HL_inv = k2.arc_sort(k2.invert(HL)) - HL_inv = k2.arc_sort(k2.connect(k2.compose(L.invert(), H_inv))) - #print(k2.is_rand_equivalent(HL_inv,HL_inv_new,log_semiring=True)) - save_path = os.path.join(args.out_dir, "HLinv.pt") - torch.save(HL_inv.as_dict(), save_path) - logger.info(f"saved the HL_inv fst as {save_path}") - - den_graph = k2.arc_sort(k2.invert(k2.connect(k2.intersect(H_inv, phone_lm)))) + den_graph = k2.invert(k2.connect(k2.intersect(H_inv, phone_lm))) + den_graph = k2.connect(k2.remove_epsilons_iterative_tropical(den_graph)) + den_graph = k2.arc_sort(k2.connect(k2.determinize(den_graph))) + assert (den_graph.labels == 0).int().sum().item() == 0 + den_graph.labels = torch.where(den_graph.labels > 0, den_graph.labels - 1, den_graph.labels) # restore the label indices del den_graph.aux_labels save_path = 
os.path.join(args.out_dir, "denominator.pt") torch.save(den_graph.as_dict(), save_path) diff --git a/examples/mobvoihotwords/local/create_decoding_graph.py b/examples/mobvoihotwords/local/create_decoding_graph.py index 23e00d86d..8ba77c969 100755 --- a/examples/mobvoihotwords/local/create_decoding_graph.py +++ b/examples/mobvoihotwords/local/create_decoding_graph.py @@ -26,9 +26,18 @@ def get_parser(): description="Create the decoding graph for decoding" ) # fmt: off - parser.add_argument("--HCL-inv-path", type=str, help="path to the HCL_inv fst file (torch_saved)", required=True) + parser.add_argument("--H-path", type=str, help="path to the H fst file (torch_saved). Note: pdf-ids are offset by +1", required=True) + parser.add_argument("--L-path", type=str, help="path to the L fst file (openfst text format or torch saved)", required=True) parser.add_argument("--G-path", type=str, help="path to the LM fsa (openfst text format or torch saved)", required=True) - parser.add_argument("--out-dir", type=str, default="data", help="directory to save the decoding graph") + parser.add_argument( + "--first-phone-disambig-id", type=int, default=999, + help="An integer ID corresponding to the first disambiguation symbol in the phonetic alphabet" + ) + parser.add_argument( + "--first-word-disambig-id", type=int, default=999999, + help="An integer ID corresponding to the first disambiguation symbol in the words vocabulary" + ) + parser.add_argument("out_dir", type=str, help="directory to save the decoding graph") # fmt: on return parser @@ -40,21 +49,44 @@ def main(args): except ImportError: raise ImportError("Please install k2 by `pip install k2`") - HCL_inv = k2.arc_sort(k2.Fsa.from_dict(torch.load(args.HCL_inv_path))) + H_inv = k2.arc_sort(k2.invert(k2.Fsa.from_dict(torch.load(args.H_path)))) - if args.lm_fsa_path[-3:] == ".pt": - G = k2.Fsa.from_dict(torch.load(args.G_path)) + if args.L_path[-3:] == ".pt": + L_inv = k2.arc_sort(k2.Fsa.from_dict(torch.load(args.L_path)).invert_()) else: - with open(args.lm_fsa_path, "r", encoding="utf-8") as f: - G = k2.Fsa.from_openfst(f.read(), acceptor=True) - assert not hasattr(G, "aux_labels") - G = k2.arc_sort(G) - - HCLG = k2.invert(k2.connect(k2.intersect(G, HCL_inv))) - HCLG = k2.determinize(HCLG) - HCLG = k2.connect(HCLG) + with open(args.L_path, "r", encoding="utf-8") as f: + L_inv = k2.arc_sort(k2.Fsa.from_openfst(f.read(), acceptor=False).invert_()) + + if args.G_path[-3:] == ".pt": + G = k2.arc_sort(k2.Fsa.from_dict(torch.load(args.G_path))) + else: + with open(args.G_path, "r", encoding="utf-8") as f: + G = k2.arc_sort(k2.Fsa.from_openfst(f.read(), acceptor=True)) + + LG = k2.connect(k2.intersect(G, L_inv)).invert_() + LG = k2.connect(k2.determinize(LG)) + LG.labels[LG.labels >= args.first_phone_disambig_id] = 0 + if isinstance(LG.aux_labels, torch.Tensor): + LG.aux_labels[LG.aux_labels >= args.first_word_disambig_id] = 0 + else: + LG.aux_labels.values()[LG.aux_labels.values() >= args.first_word_disambig_id] = 0 + LG = k2.arc_sort(k2.connect(k2.remove_epsilons_iterative_tropical(LG))) + + LG.temp_labels = LG.aux_labels + del LG.aux_labels + HLG_inv = k2.connect(k2.intersect(LG, H_inv)) + HLG_inv.labels = HLG_inv.temp_labels + del HLG_inv.temp_labels + HLG = k2.invert(HLG_inv) + + HLG = k2.connect(k2.remove_epsilons_iterative_tropical(HLG)) + #HLG = k2.connect(k2.determinize(HLG)) + HLG.labels = torch.where(HLG.labels > 0, HLG.labels - 1, HLG.labels) + HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0) + + os.makedirs(args.out_dir, 
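Every graph in this patch is built with the same substitute for FST composition, since k2.intersect() matches only the `labels` of its two arguments. A rough sketch of the pattern (the helper name is ours; it assumes the acceptor is defined over the transducer's output alphabet):

import k2


def compose_with_acceptor(t: k2.Fsa, a: k2.Fsa) -> k2.Fsa:
    """Roughly T o A: `labels` carry T's input symbols, `aux_labels` its output symbols,
    restricted to the label sequences accepted by A."""
    t_inv = k2.arc_sort(k2.invert(t))          # labels become T's output symbols
    out = k2.intersect(k2.arc_sort(a), t_inv)  # match against A; aux_labels are T's inputs
    return k2.connect(k2.invert(out))          # swap back: labels are T's inputs again

With t = L and a a linear word FSA this is the LG used for the numerator graphs; with t = H (pdf-ids to phones) and a = LG or the phone LM it yields HLG and the denominator graph, respectively.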
exist_ok=True) save_path = os.path.join(args.out_dir, "HCLG.pt") - torch.save(HCLG.as_dict(), save_path) + torch.save(HLG.as_dict(), save_path) logger.info(f"saved the decoding graph as {save_path}") diff --git a/examples/mobvoihotwords/local/decode_best_path.py b/examples/mobvoihotwords/local/decode_best_path.py index 2fa6be205..8e5d23a4f 100755 --- a/examples/mobvoihotwords/local/decode_best_path.py +++ b/examples/mobvoihotwords/local/decode_best_path.py @@ -9,6 +9,8 @@ import os import sys +import numpy as np + import torch @@ -47,15 +49,15 @@ def main(args): raise ImportError("Please install kaldi_io by `pip install kaldi_io`") symbol_table = k2.SymbolTable.from_file(args.word_symbol_table) - graph = k2.Fsa.from_dict(torch.load(args.args.decoding_graph)) + graph = k2.Fsa.from_dict(torch.load(args.decoding_graph)) graph.scores.requires_grad_(False) num_processed = 0 - with open(args.net_output, "r", encoding="utf-8") as f_in, open(args.hyp_file, "r", encoding="utf-8") as f_out: + with open(args.net_output, "r", encoding="utf-8") as f_in, open(args.hyp_file, "w", encoding="utf-8") as f_out: for line in f_in: utt_id, rxfile = line.strip().split(maxsplit=1) - net_output = torch.from_numpy(kaldi_io.read_mat(rxfile)).unsqueeze(0) # 1 x T x V - supervision_segments = net_output.new_tensor([0, 0, net_output.size(0)], dtype=torch.int).unsqueeze(0) # 1 x 3 + net_output = torch.from_numpy(np.array(kaldi_io.read_mat(rxfile))).float().unsqueeze(0) # 1 x T x V + supervision_segments = net_output.new_tensor([0, 0, net_output.size(1)], dtype=torch.int).unsqueeze(0) # 1 x 3 dense_fsa_vec = k2.DenseFsaVec(net_output, supervision_segments) graph = graph.to(dense_fsa_vec.device) graph_unrolled = k2.intersect_dense_pruned( @@ -70,7 +72,7 @@ def main(args): aux_labels = aux_labels[aux_labels > 0] aux_labels = aux_labels.tolist() hyp = [symbol_table.get(x) for x in aux_labels] - print(utt_id, hyp, file=f_out) + print(utt_id, " ".join(hyp), file=f_out) num_processed += 1 logger.info(f"Processed {num_processed} utterances") diff --git a/examples/mobvoihotwords/local/evaluate.py b/examples/mobvoihotwords/local/evaluate.py index 3717c620f..e42903337 100755 --- a/examples/mobvoihotwords/local/evaluate.py +++ b/examples/mobvoihotwords/local/evaluate.py @@ -41,7 +41,7 @@ def main(args): except ImportError: raise ImportError("Please install Lhotse by `pip install lhotse`") - supervisions = SupervisionSet.from_json(args.recording_file) # one and only one supervision segment per recording + supervisions = SupervisionSet.from_json(args.supervsion_file) # one and only one supervision segment per recording neg_dur = sum(sup.duration for sup in supervisions if sup.text != args.wake_word) ref = [(sup.recording_id, sup.text) for sup in supervisions] @@ -78,7 +78,7 @@ def main(args): with open(args.result_file, "w", encoding="utf-8") as f: print( "precision: {:.5f} recall: {:.5f} FPR: {:.5f} FNR: {:.5f} FP per hour: {:.5f} total: {:d}".format( - precision, recall, false_positive_rate, false_negative_rate, false_alarms_per_hour, TP + TN + FP + FN + precision, recall, false_positive_rate, false_negative_rate, false_alarms_per_hour, int(TP + TN + FP + FN) ), file=f ) diff --git a/examples/mobvoihotwords/run.sh b/examples/mobvoihotwords/run.sh index 1fb6e5bca..336c8baa3 100755 --- a/examples/mobvoihotwords/run.sh +++ b/examples/mobvoihotwords/run.sh @@ -43,8 +43,12 @@ NihaoWenwen 1.0 nihaowenwen 1.0 SIL EOF - utils/lang/make_lexicon_fst.py --sil-prob 0.5 --sil-phone SIL \ - data/lang/lexiconp.txt > data/lang/L.fst.sym + 
utils/lang/make_lexicon_fst.py --sil-prob 0.5 --sil-phone SIL --sil-disambig '#1' \ + data/lang/lexiconp.txt > data/lang/L_disambig.fst.sym + cat <(head -n -1 data/lang/L_disambig.fst.sym) <(echo -e "1\t1\t#0\t#0") <(tail -n 1 data/lang/L_disambig.fst.sym) \ + > data/lang/L_disambig.fst.sym.temp + cat data/lang/L_disambig.fst.sym.temp > data/lang/L_disambig.fst.sym + rm -f data/lang/L_disambig.fst.sym.temp echo "Prepare phones symbol table" cat > data/lang/phones.txt < data/lang/L.fst.txt + utils/sym2int.pl -f 3 data/lang/phones.txt data/lang/L_disambig.fst.txt echo "Prepare HMMs for phones" id_sil=`cat data/lang/phones.txt | grep "SIL" | awk '{print $2}'` @@ -132,9 +139,8 @@ EOF echo "Generate graphs for training" log_file=data/log/generate_graphs.log - $train_cmd $log_file local/generate_graphs.py --hmm-paths data/lang/hmm_{sil,freetext,hixiaowen,nihaowenwen}.fst.txt \ - --L-path data/lang/L.fst.txt --phone-lm-fsa-path data/lang/phone_lm.fsa.txt \ - --out-dir data + $train_cmd $log_file local/create_H_and_denominator.py --hmm-paths data/lang/hmm_{sil,freetext,hixiaowen,nihaowenwen}.fst.txt \ + --phone-lm-fsa-path data/lang/phone_lm.fsa.txt --out-dir data fi [ -z "$free_gpu" ] && [[ $(hostname -f) == *.clsp.jhu.edu ]] && free_gpu=$(free-gpu -n $ngpus) || \ @@ -163,7 +169,8 @@ if [ ${stage} -le 2 ]; then --save-dir $dir --restore-file checkpoint_last.pt --save-interval-updates $((1500/ngpus/update_freq)) \ --keep-interval-updates 5 --keep-last-epochs 5 --validate-interval 1 \ --criterion k2_lattice_free_mmi --num-targets $num_targets --word-symbol-table-path data/lang/words.txt \ - --denominator-graph-path data/denominator.pt --HCL-inv-path data/HLinv.pt \ + --phone-symbol-table-path data/lang/phones.txt --denominator-graph-path data/denominator.pt \ + --H-path data/H.pt --L-path data/lang/L_disambig.fst.txt \ --max-source-positions 9999 --max-target-positions 9999 $opts || exit 1; fi @@ -176,20 +183,26 @@ if [ ${stage} -le 3 ]; then $cuda_cmd $log_file dump_posteriors.py data --use-k2-dataset \ --task speech_recognition_hybrid --max-tokens 25600 --max-sentences 128 \ --num-shards 1 --shard-id 0 --num-targets $num_targets --gen-subset $dataset \ - --max-source-positions 9999 --path $path \ + --max-source-positions 9999 --max-target-positions 9999 --path $path \ \| copy-matrix ark:- ark,scp:$dir/decode_$dataset/posteriors.ark,$dir/decode_$dataset/posteriors.scp || exit 1; echo "log saved in $log_file" done fi if [ ${stage} -le 4 ]; then - echo "Stage 7: Decoding" + echo "Stage 4: Decoding" lang_test=data/lang_test rm -rf $lang_test cp -r data/lang $lang_test - utils/lang/make_lexicon_fst.py --sil-prob 0.0 --sil-phone SIL $lang_test/lexiconp.txt > $lang_test/L.fst.sym - utils/sym2int.pl -f 3 $lang_test/phones.txt <$lang_test/L.fst.sym - | \ - utils/sym2int.pl -f 4 $lang_test/words.txt - > $lang_test/L.fst.txt + utils/lang/make_lexicon_fst.py --sil-prob 0.0 --sil-phone SIL --sil-disambig '#1' \ + $lang_test/lexiconp.txt > $lang_test/L_disambig.fst.sym + cat <(head -n -1 $lang_test/L_disambig.fst.sym) <(echo -e "0\t0\t#0\t#0") <(tail -n 1 $lang_test/L_disambig.fst.sym) \ + > $lang_test/L_disambig.fst.sym.temp + cat $lang_test/L_disambig.fst.sym.temp > $lang_test/L_disambig.fst.sym + rm -f $lang_test/L_disambig.fst.sym.temp + + utils/sym2int.pl -f 3 $lang_test/phones.txt <$lang_test/L_disambig.fst.sym - | \ + utils/sym2int.pl -f 4 $lang_test/words.txt - > $lang_test/L_disambig.fst.txt for wake_word in $wake_word0 $wake_word1; do if [[ "$wake_word" == "$wake_word0" ]]; then @@ -210,26 +223,26 @@ 
if [ ${stage} -le 4 ]; then 0 1 $sil_id 0 4 $sil_id 7.0 1 4 $freetext_id 0.0 -4 0 $sil_id 0.0 +4 0 $sil_id 1 2 $id0 $wake_word0_cost 1 3 $id1 $wake_word1_cost 2 0 $sil_id 3 0 $sil_id 0 EOF - local/create_decoding_graph.py --HCL-inv-path data/HLinv.pt --G-path $lang_test/lm/fsa.txt $lang_test/graph || exit 1; + local/create_decoding_graph.py --H-path data/H.pt --L-path $lang_test/L_disambig.fst.txt --G-path $lang_test/lm/fsa.txt \ + --first-phone-disambig-id 5 --first-word-disambig-id 5 $lang_test/graph || exit 1; rm $dir/.error 2>/dev/null || true for dataset in $test_set; do ( - nj=30 score_dir=$dir/decode_$dataset/score_${wake_word}_${wake_word0_cost}_${wake_word1_cost} mkdir -p $score_dir $decode_cmd $dir/decode_$dataset/log/decode_${wake_word}.log \ local/decode_best_path.py --beam=10 --word-symbol-table $lang_test/words.txt \ - $lang_test/graph/HCLG.pt $dir/decode_$dataset/posteriors.scp $score_dir/hyp.txt + $lang_test/graph/HCLG.pt $dir/decode_$dataset/posteriors.scp $score_dir/hyp.txt || exit 1; local/evaluate.py --wake-word $wake_word \ - data/supervisions_${dataset}.json $score_dir/hyp.txt $score_dir/metrics + data/supervisions_${dataset}.json $score_dir/hyp.txt $score_dir/metrics || exit 1; ) || touch $dir/.error & done wait From 1b059663d90d45a46b91e10c17530e498ec0e9a0 Mon Sep 17 00:00:00 2001 From: freewym Date: Sun, 27 Dec 2020 22:21:33 -0500 Subject: [PATCH 119/119] refactor code --- espresso/criterions/k2_lf_mmi_loss.py | 94 +++++++++++++------ .../mobvoihotwords/local/decode_best_path.py | 2 +- examples/mobvoihotwords/run.sh | 8 +- 3 files changed, 68 insertions(+), 36 deletions(-) diff --git a/espresso/criterions/k2_lf_mmi_loss.py b/espresso/criterions/k2_lf_mmi_loss.py index 768c7c841..99c48c846 100644 --- a/espresso/criterions/k2_lf_mmi_loss.py +++ b/espresso/criterions/k2_lf_mmi_loss.py @@ -8,7 +8,7 @@ import logging import math from omegaconf import II -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Iterable, List, Optional import torch from torch import Tensor @@ -53,62 +53,98 @@ class K2LatticeFreeMMICriterionConfig(FairseqDataclass): ) -def compile_numerator_graphs( - texts: List[str], symbols: k2.SymbolTable, H_inv: k2.Fsa, L_inv: k2.Fsa, first_phone_disambig_id: int, - unk_str: Optional[str] = "UNK", den_graph: Optional[k2.Fsa] = None -): - assert len(den_graph.shape) == 2 +class TrainingGraphCompiler(object): + """ + :class:`TrainingGraphCompiler` is used to create training graphs (numerator graphs) for LF-MMI. + + Args: + H_inv (k2.Fsa): invert of H. 
Note: H.labels has been offset by +1 + L_inv (k2.Fsa): invert of L + symbol_table (k2.SymbolTable): word symbol table + phone_symbol_table (k2.SymbolTable): phone symbol table + unk_str (optional: str): unk string + den_graph (optional: k2.Fsa): denominator graph, to be composed with numerator graphs for normalization + """ + def __init__( + self, H_inv: k2.Fsa, L_inv: k2.Fsa, symbol_table: k2.SymbolTable, phone_symbol_table: k2.SymbolTable, + unk_str: Optional[str] = None, den_graph: Optional[k2.Fsa] = None + ): + if H_inv.properties & k2.fsa_properties.ARC_SORTED == 0: + H_inv = k2.arc_sort(H_inv) + if L_inv.properties & k2.fsa_properties.ARC_SORTED == 0: + L_inv = k2.arc_sort(L_inv) + + if unk_str is not None: + assert unk_str in symbol_table + if den_graph is not None: + assert len(den_graph.shape) == 2 + + self.H_inv = H_inv + self.L_inv = L_inv + self.symbol_table = symbol_table + self.first_phone_disambig_id = min(v for k, v in phone_symbol_table._sym2id.items() if k.startswith("#")) + self.unk_str = unk_str + self.den_graph = den_graph + + def compile(self, texts: Iterable[str]) -> k2.Fsa: + num_graphs = k2.create_fsa_vec([self.compile_one_and_cache(text) for text in texts]) + num_graphs.requires_grad_(False) + return num_graphs @lru_cache(maxsize=100000) - def compile_one_and_cache(text: str) -> k2.Fsa: - filtered_text = [token if token in symbols._sym2id else unk_str for token in text.split(" ")] - word_ids = [symbols.get(word) for word in filtered_text] + def compile_one_and_cache(self, text: str) -> k2.Fsa: + if self.unk_str is not None: + tokens = [token if token in self.symbol_table._sym2id else self.unk_str for token in text.split(" ")] + else: + tokens = [token for token in text.split(" ") if token in self.symbol_table._sym2id] + word_ids = [self.symbol_table[token] for token in tokens] fsa = k2.linear_fsa(word_ids) - LG = k2.connect(k2.intersect(fsa, L_inv)).invert_() + LG = k2.connect(k2.intersect(fsa, self.L_inv)).invert_() LG = k2.connect(k2.determinize(LG)) - LG.labels[LG.labels >= first_phone_disambig_id] = 0 + LG.labels[LG.labels >= self.first_phone_disambig_id] = 0 LG = k2.arc_sort(k2.connect(k2.remove_epsilons_iterative_tropical(LG))) del LG.aux_labels - num_graph = k2.arc_sort(k2.invert(k2.connect(k2.intersect(H_inv, LG)))) + num_graph = k2.arc_sort(k2.invert(k2.connect(k2.intersect(self.H_inv, LG)))) num_graph = k2.connect(k2.remove_epsilons_iterative_tropical(num_graph)) num_graph = k2.connect(k2.determinize(num_graph)) del num_graph.aux_labels num_graph.labels = torch.where(num_graph.labels > 0, num_graph.labels - 1, num_graph.labels) - if den_graph is not None: + if self.den_graph is not None: num_graph = k2.arc_sort(num_graph) # zero the score before intersect to avoid double counting num_graph.scores = num_graph.scores.new_zeros(num_graph.scores.size()) # treat epsilon as normal labels, i.e., blanks - num_graph = k2.connect(k2.intersect(num_graph, den_graph, treat_epsilons_specially=False)) + num_graph = k2.connect(k2.intersect(num_graph, self.den_graph, treat_epsilons_specially=False)) return num_graph - num_graphs = k2.create_fsa_vec([compile_one_and_cache(text) for text in texts]) - return num_graphs - @register_criterion("k2_lattice_free_mmi", dataclass=K2LatticeFreeMMICriterionConfig) class K2LatticeFreeMMICriterion(FairseqCriterion): def __init__(self, cfg: K2LatticeFreeMMICriterionConfig, task: FairseqTask): super().__init__(task) - self.unk_str = task.target_dictionary.unk_string() if task.target_dictionary is not None else "UNK" self.sentence_avg = 
cfg.sentence_avg
         self.den_graph = k2.create_fsa_vec(
             [k2.Fsa.from_dict(torch.load(cfg.denominator_graph_path))]
         )  # has to be an FsaVec to be able to intersect with a batch of dense fsas
         if hasattr(self.den_graph, "aux_labels"):
             del self.den_graph.aux_labels
-        self.den_graph.scores.requires_grad_(False)
-        self.den_graph_cpu = self.den_graph[0].clone()  # to be intersect with a individual numerator fsa
-        self.H_inv = k2.arc_sort(k2.invert(k2.Fsa.from_dict(torch.load(cfg.H_path))))
+        self.den_graph.requires_grad_(False)
+        H_inv = k2.invert(k2.Fsa.from_dict(torch.load(cfg.H_path)))
         if cfg.L_path[-3:] == ".pt":
-            self.L_inv = k2.arc_sort(k2.Fsa.from_dict(torch.load(args.L_path)).invert_())
+            L_inv = k2.Fsa.from_dict(torch.load(cfg.L_path)).invert_()
         else:
             with open(cfg.L_path, "r", encoding="utf-8") as f:
-                self.L_inv = k2.arc_sort(k2.Fsa.from_openfst(f.read(), acceptor=False).invert_())
-        self.symbol_table = k2.SymbolTable.from_file(cfg.word_symbol_table_path)
-        phone_symbol_table = k2.SymbolTable.from_file(cfg.phone_symbol_table_path)
-        self.first_phone_disambig_id = min(v for k, v in phone_symbol_table._sym2id.items() if k.startswith("#"))
+                L_inv = k2.Fsa.from_openfst(f.read(), acceptor=False).invert_()
+        self.graph_compiler = TrainingGraphCompiler(
+            H_inv=H_inv,
+            L_inv=L_inv,
+            symbol_table=k2.SymbolTable.from_file(cfg.word_symbol_table_path),
+            phone_symbol_table=k2.SymbolTable.from_file(cfg.phone_symbol_table_path),
+            unk_str=task.target_dictionary.unk_string() if task.target_dictionary is not None else None,
+            den_graph=self.den_graph[0].clone(),  # to be intersected with a single numerator fsa
+        )
+
         self.xent_regularize = cfg.xent_regularization_coefficient
         self.subsampling_factor = None
@@ -160,11 +196,7 @@ def compute_loss(
         dense_fsa_vec = k2.DenseFsaVec(encoder_out, supervision_segments)

         # numerator computation
-        num_graphs = compile_numerator_graphs(
-            sample["target"]["text"], self.symbol_table, H_inv=self.H_inv, L_inv=self.L_inv, first_phone_disambig_id=self.first_phone_disambig_id,
-            unk_str=self.unk_str, den_graph=self.den_graph_cpu
-        ).to(encoder_out.device)
-        num_graphs.scores.requires_grad_(False)
+        num_graphs = self.graph_compiler.compile(sample["target"]["text"]).to(encoder_out.device)
         num_graphs_unrolled = k2.intersect_dense_pruned(
             num_graphs, dense_fsa_vec, search_beam=100000, output_beam=100000, min_active_states=0, max_active_states=10000
         )
diff --git a/examples/mobvoihotwords/local/decode_best_path.py b/examples/mobvoihotwords/local/decode_best_path.py
index 8e5d23a4f..3c5758f02 100755
--- a/examples/mobvoihotwords/local/decode_best_path.py
+++ b/examples/mobvoihotwords/local/decode_best_path.py
@@ -29,7 +29,7 @@ def get_parser():
     )
     # fmt: off
     parser.add_argument("--beam", type=float, default=10.0, help="decoding beam")
-    parser.add_argument("--word-symbol-table", type=str, help="path to the HCL fst file (torch_saved)", required=True)
+    parser.add_argument("--word-symbol-table", type=str, help="path to the word symbol table file", required=True)
     parser.add_argument("decoding_graph", type=str, default="data", help="path to the decoding graph")
     parser.add_argument("net_output", type=str, help="path to the network output file for acoustic scores")
     parser.add_argument("hyp_file", type=str, help="path to the resulting hypotheses file")
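For reference, the decoding entry point just above reduces, for a single utterance, to something like the following (the function name, variable names, and pruning constants are illustrative rather than taken from the script):

import torch
import k2


def best_path_word_ids(log_probs: torch.Tensor, graph: k2.Fsa, beam: float = 10.0):
    """log_probs: T x V output of the network for one utterance; graph: the saved decoding graph."""
    net_output = log_probs.unsqueeze(0)  # 1 x T x V
    # a single supervision segment covering the whole utterance: (fsa_index, start_frame, num_frames)
    supervision_segments = torch.tensor([[0, 0, net_output.size(1)]], dtype=torch.int)
    dense_fsa_vec = k2.DenseFsaVec(net_output, supervision_segments)
    lattice = k2.intersect_dense_pruned(
        graph.to(dense_fsa_vec.device), dense_fsa_vec,
        search_beam=beam, output_beam=beam, min_active_states=30, max_active_states=10000,
    )
    best_path = k2.shortest_path(lattice, use_double_scores=True)[0]
    aux_labels = best_path.aux_labels
    if not isinstance(aux_labels, torch.Tensor):  # aux_labels may be ragged
        aux_labels = aux_labels.values()
    return aux_labels[aux_labels > 0].tolist()  # drop epsilons and the final -1

The script itself loops over the posteriors.scp written in stage 3 and prints one "utt_id hypothesis" line per recording, which is what local/evaluate.py then scores.

diff --git a/examples/mobvoihotwords/run.sh b/examples/mobvoihotwords/run.sh
index 336c8baa3..12cb912b3 100755
--- a/examples/mobvoihotwords/run.sh
+++ b/examples/mobvoihotwords/run.sh
@@ -206,11 +206,11 @@ if [ ${stage} -le 4 ]; then
   for 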
wake_word in $wake_word0 $wake_word1; do if [[ "$wake_word" == "$wake_word0" ]]; then - wake_word0_cost_range="-1.5 -1.0 -0.5 0.0 0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0" - wake_word1_cost_range="0.0" + wake_word0_cost_range="3.0 4.0 5.0 10.0 15.0 20.0" + wake_word1_cost_range="2.0" else - wake_word0_cost_range="0.0" - wake_word1_cost_range="-1.5 -1.0 -0.5 0.0 0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0" + wake_word0_cost_range="2.0" + wake_word1_cost_range="3.0 4.0 5.0 10.0 15.0 20.0" fi for wake_word0_cost in $wake_word0_cost_range; do for wake_word1_cost in $wake_word1_cost_range; do
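Each sweep point ends, as in the earlier version of this loop, with local/decode_best_path.py followed by local/evaluate.py. Ignoring alignment details, the numbers that evaluate.py reports reduce to roughly the following (an illustrative simplification assuming one detection decision per recording; `ref` and `hyp` map recording ids to text, and `negative_seconds` is the total duration of the non-wake-word recordings):

def wake_word_metrics(ref, hyp, wake_word, negative_seconds):
    tp = sum(1 for rid, text in ref.items() if text == wake_word and wake_word in hyp.get(rid, ""))
    fn = sum(1 for rid, text in ref.items() if text == wake_word and wake_word not in hyp.get(rid, ""))
    fp = sum(1 for rid, text in ref.items() if text != wake_word and wake_word in hyp.get(rid, ""))
    tn = sum(1 for rid, text in ref.items() if text != wake_word and wake_word not in hyp.get(rid, ""))
    precision = tp / (tp + fp) if tp + fp > 0 else float("nan")
    recall = tp / (tp + fn) if tp + fn > 0 else float("nan")
    false_positive_rate = fp / (fp + tn) if fp + tn > 0 else float("nan")
    false_negative_rate = fn / (fn + tp) if fn + tp > 0 else float("nan")
    false_alarms_per_hour = fp / (negative_seconds / 3600.0) if negative_seconds > 0 else float("nan")
    return precision, recall, false_positive_rate, false_negative_rate, false_alarms_per_hour

Raising a wake word's cost in the small decoding FSA above makes that path more expensive, trading recall for a lower false-alarm rate, so sweeping the cost traces out an operating curve for each keyword.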