From 1f1255c53199bb4a9c8c1b15d4c68008b8f9e8c6 Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Mon, 5 Jul 2021 12:00:44 +0800 Subject: [PATCH 1/9] update bpe models and integrate 4-gram rescore --- .../simple_v1/bpe_ctc_att_conformer_decode.py | 533 ++++++++++++++---- egs/librispeech/asr/simple_v1/bpe_run.sh | 81 ++- .../asr/simple_v1/generate_bpe_lexicon.py | 63 +++ 3 files changed, 575 insertions(+), 102 deletions(-) create mode 100644 egs/librispeech/asr/simple_v1/generate_bpe_lexicon.py diff --git a/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py b/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py index ffe2f08f..34a613f0 100755 --- a/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py +++ b/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py @@ -6,83 +6,301 @@ import argparse import logging import os -import random -import re -import sys +from collections import defaultdict from pathlib import Path +from typing import Any +from typing import Dict +from typing import List +from typing import Optional from typing import Union import k2 import numpy as np import torch -from snowfall.data import LibriSpeechAsrDataModule +from k2 import Fsa, SymbolTable + from snowfall.common import average_checkpoint, store_transcripts +from snowfall.common import find_first_disambig_symbol from snowfall.common import get_texts from snowfall.common import load_checkpoint from snowfall.common import str2bool from snowfall.common import write_error_stats +from snowfall.data import LibriSpeechAsrDataModule +from snowfall.decoding.graph import compile_HLG +from snowfall.decoding.lm_rescore import rescore_with_n_best_list +from snowfall.decoding.lm_rescore import rescore_with_whole_lattice +from snowfall.models import AcousticModel from snowfall.models.conformer import Conformer from snowfall.text.numericalizer import Numericalizer from snowfall.training.ctc_graph import build_ctc_topo +from snowfall.training.mmi_graph import get_phone_symbols + + +def nbest_decoding(lats: k2.Fsa, num_paths: int): + ''' + (Ideas of this function are from Dan) + + It implements something like CTC prefix beam search using n-best lists + + The basic idea is to first extra n-best paths from the given lattice, + build a word seqs from these paths, and compute the total scores + of these sequences in the log-semiring. The one with the max score + is used as the decoding output. + ''' + + # First, extract `num_paths` paths for each sequence. + # paths is a k2.RaggedInt with axes [seq][path][arc_pos] + paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True) + + # word_seqs is a k2.RaggedInt sharing the same shape as `paths` + # but it contains word IDs. Note that it also contains 0s and -1s. + # The last entry in each sublist is -1. + + word_seqs = k2.index(lats.aux_labels, paths) + # Note: the above operation supports also the case when + # lats.aux_labels is a ragged tensor. In that case, + # `remove_axis=True` is used inside the pybind11 binding code, + # so the resulting `word_seqs` still has 3 axes, like `paths`. + # The 3 axes are [seq][path][word] + + # Remove epsilons and -1 from word_seqs + word_seqs = k2.ragged.remove_values_leq(word_seqs, 0) + + # Remove repeated sequences to avoid redundant computation later. + # + # Since k2.ragged.unique_sequences will reorder paths within a seq, + # `new2old` is a 1-D torch.Tensor mapping from the output path index + # to the input path index. 
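A library-free sketch of the bookkeeping described in the comments above (the helper name and the toy word IDs are made up; in the patch itself this is done on ragged tensors by k2.ragged.unique_sequences, which returns the `new2old` mapping used later to go from a surviving unique path back to one of the originally sampled paths):

def unique_with_new2old(word_seqs):
    # Deduplicate whole sequences, remembering one original index per survivor.
    seen = {}
    unique, new2old = [], []
    for old_idx, seq in enumerate(word_seqs):
        key = tuple(seq)
        if key not in seen:
            seen[key] = len(unique)
            unique.append(list(seq))
            new2old.append(old_idx)
    return unique, new2old

paths = [[5, 7, 9], [5, 7, 9], [5, 8, 9]]   # word IDs of three sampled paths
unique, new2old = unique_with_new2old(paths)
assert unique == [[5, 7, 9], [5, 8, 9]]
assert new2old == [0, 2]   # unique path i came from paths[new2old[i]]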
+ # new2old.numel() == unique_word_seqs.num_elements() + unique_word_seqs, _, new2old = k2.ragged.unique_sequences( + word_seqs, need_num_repeats=False, need_new2old_indexes=True) + # Note: unique_word_seqs still has the same axes as word_seqs + + seq_to_path_shape = k2.ragged.get_layer(unique_word_seqs.shape(), 0) + + # path_to_seq_map is a 1-D torch.Tensor. + # path_to_seq_map[i] is the seq to which the i-th path + # belongs. + path_to_seq_map = seq_to_path_shape.row_ids(1) + + # Remove the seq axis. + # Now unique_word_seqs has only two axes [path][word] + unique_word_seqs = k2.ragged.remove_axis(unique_word_seqs, 0) + + # word_fsas is an FsaVec with axes [path][state][arc] + word_fsas = k2.linear_fsa(unique_word_seqs) + + word_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsas) + + # lats has phone IDs as labels and word IDs as aux_labels. + # inv_lats has word IDs as labels and phone IDs as aux_labels + inv_lats = k2.invert(lats) + inv_lats = k2.arc_sort(inv_lats) # no-op if inv_lats is already arc-sorted + + path_lats = k2.intersect_device(inv_lats, + word_fsas_with_epsilon_loops, + b_to_a_map=path_to_seq_map, + sorted_match_a=True) + # path_lats has word IDs as labels and phone IDs as aux_labels + + path_lats = k2.top_sort(k2.connect(path_lats.to('cpu')).to(lats.device)) + + tot_scores = path_lats.get_tot_scores(True, True) + # RaggedFloat currently supports float32 only. + # We may bind Ragged as RaggedDouble if needed. + ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape, + tot_scores.to(torch.float32)) + + argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores) + + # Since we invoked `k2.ragged.unique_sequences`, which reorders + # the index from `paths`, we use `new2old` + # here to convert argmax_indexes to the indexes into `paths`. + # + # Use k2.index here since argmax_indexes' dtype is torch.int32 + best_path_indexes = k2.index(new2old, argmax_indexes) + + paths_2axes = k2.ragged.remove_axis(paths, 0) + + # best_paths is a k2.RaggedInt with 2 axes [path][arc_pos] + best_paths = k2.index(paths_2axes, best_path_indexes) + + # labels is a k2.RaggedInt with 2 axes [path][phone_id] + # Note that it contains -1s. + labels = k2.index(lats.labels.contiguous(), best_paths) + + labels = k2.ragged.remove_values_eq(labels, -1) + + # lats.aux_labels is a k2.RaggedInt tensor with 2 axes, so + # aux_labels is also a k2.RaggedInt with 2 axes + aux_labels = k2.index(lats.aux_labels, best_paths.values()) + + best_path_fsas = k2.linear_fsa(labels) + best_path_fsas.aux_labels = aux_labels + + return best_path_fsas +def decode_one_batch(batch: Dict[str, Any], + model: AcousticModel, + HLG: k2.Fsa, + output_beam_size: float, + num_paths: int, + use_whole_lattice: bool, + G: Optional[k2.Fsa] = None)->Dict[str, List[List[int]]]: + ''' + Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + + Args: + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + model: + The neural network model. + HLG: + The decoding graph. 
+ output_beam_size: + Size of the beam for pruning. + use_whole_lattice: + If True, `G` must not be None and it will use whole lattice for + LM rescoring. + If False and if `G` is not None, then `num_paths` must be positive + and it will use n-best list for LM rescoring. + num_paths: + It specifies the size of `n` in n-best list decoding. + G: + The LM. If it is None, no rescoring is used. + Otherwise, LM rescoring is used. + It supports two types of LM rescoring: n-best list rescoring + and whole lattice rescoring. + `use_whole_lattice` specifies which type to use. + + Returns: + Return the decoding result. See above description for the format of + the returned dict. + ''' + device = HLG.device + feature = batch['inputs'] + assert feature.ndim == 3 + feature = feature.to(device) + + # at entry, feature is [N, T, C] + feature = feature.permute(0, 2, 1) # now feature is [N, C, T] + + supervisions = batch['supervisions'] + + nnet_output, _, _ = model(feature, supervisions) + # nnet_output is [N, C, T] + + nnet_output = nnet_output.permute(0, 2, 1) + # now nnet_output is [N, T, C] + + supervision_segments = torch.stack( + (supervisions['sequence_idx'], + (((supervisions['start_frame'] - 1) // 2 - 1) // 2), + (((supervisions['num_frames'] - 1) // 2 - 1) // 2)), + 1).to(torch.int32) + + supervision_segments = torch.clamp(supervision_segments, min=0) + indices = torch.argsort(supervision_segments[:, 2], descending=True) + supervision_segments = supervision_segments[indices] + + dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments) + + lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0, output_beam_size, 30, 10000) + + if G is None: + if num_paths > 1: + best_paths = nbest_decoding(lattices, num_paths) + key=f'no_rescore-{num_paths}' + else: + key = 'no_rescore' + best_paths = k2.shortest_path(lattices, use_double_scores=True) + hyps = get_texts(best_paths, indices) + return {key: hyps} + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + lm_scale_list += [0.45, 0.55, 0.65] + lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + + if use_whole_lattice: + best_paths_dict = rescore_with_whole_lattice(lattices, G, + lm_scale_list) + else: + best_paths_dict = rescore_with_n_best_list(lattices, G, num_paths, + lm_scale_list) + # best_paths_dict is a dict + # - key: lm_scale_xxx, where xxx is the value of lm_scale. An example + # key is lm_scale_1.2 + # - value: it is the best path obtained using the corresponding lm scale + # from the dict key. + + ans = dict() + for lm_scale_str, best_paths in best_paths_dict.items(): + hyps = get_texts(best_paths, indices) + ans[lm_scale_str] = hyps + return ans + + +@torch.no_grad() def decode(dataloader: torch.utils.data.DataLoader, - model: None, - device: Union[str, torch.device], - ctc_topo: None, - numericalizer=None, - num_paths=-1, - output_beam_size: float=8): + model: AcousticModel, + HLG: Fsa, + symbols: SymbolTable, + num_paths: int, + G: k2.Fsa, + use_whole_lattice: bool, + output_beam_size: float): tot_num_cuts = len(dataloader.dataset.cuts) num_cuts = 0 - results = [] + results = defaultdict(list) + # results is a dict whose keys and values are: + # - key: It indicates the lm_scale, e.g., lm_scale_1.2. 
+ # If no rescoring is used, the key is the literal string: no_rescore + # + # - value: It is a list of tuples (ref_words, hyp_words) for batch_idx, batch in enumerate(dataloader): - assert isinstance(batch, dict), type(batch) - feature = batch['inputs'] - supervisions = batch['supervisions'] - supervision_segments = torch.stack( - (supervisions['sequence_idx'], - (((supervisions['start_frame'] - 1) // 2 - 1) // 2), - (((supervisions['num_frames'] - 1) // 2 - 1) // 2)), 1).to(torch.int32) - supervision_segments = torch.clamp(supervision_segments, min=0) - indices = torch.argsort(supervision_segments[:, 2], descending=True) - supervision_segments = supervision_segments[indices] - texts = supervisions['text'] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is [N, T, C] - feature = feature.permute(0, 2, 1) # now feature is [N, C, T] - nnet_output, encoder_memory, memory_mask = model(feature, supervisions) - nnet_output = nnet_output.permute(0, 2, 1) - - # TODO(Liyong Guo): Tune this bias - # blank_bias = 0.0 - # nnet_output[:, :, 0] += blank_bias - - with torch.no_grad(): - dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments) - - lattices = k2.intersect_dense_pruned(ctc_topo, dense_fsa_vec, 20.0, - output_beam_size, 30, 10000) - - best_paths = k2.shortest_path(lattices, use_double_scores=True) - hyps = get_texts(best_paths, indices) - assert len(hyps) == len(texts) + texts = batch['supervisions']['text'] + + hyps_dict = decode_one_batch(batch=batch, + model=model, + HLG=HLG, + output_beam_size=output_beam_size, + num_paths=num_paths, + use_whole_lattice=use_whole_lattice, + G=G) + + for lm_scale, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + + for i in range(len(texts)): + hyp_words = [symbols.get(x) for x in hyps[i]] + ref_words = texts[i].split(' ') + this_batch.append((ref_words, hyp_words)) - for i in range(len(texts)): - pieces = [numericalizer.tokens_list[token_id] for token_id in hyps[i]] - hyp_words = numericalizer.tokenizer.DecodePieces(pieces).split(' ') - ref_words = texts[i].split(' ') - results.append((ref_words, hyp_words)) + results[lm_scale].extend(this_batch) if batch_idx % 10 == 0: logging.info( 'batch {}, cuts processed until now is {}/{} ({:.6f}%)'.format( batch_idx, num_cuts, tot_num_cuts, float(num_cuts) / tot_num_cuts * 100)) + num_cuts += len(texts) + return results @@ -97,19 +315,45 @@ def get_parser(): default="conformer", choices=["conformer"], help="Model type.") - + parser.add_argument( + '--epoch', + type=int, + default=35, + help="Decoding epoch.") + parser.add_argument( + '--avg', + type=int, + default=10, + help="Number of checkpionts to average. Automaticly select " + "consecutive checkpoints before checkpoint specified by'--epoch'. ") + parser.add_argument( + '--att-rate', + type=float, + default=0.7, + help="Attention loss rate.") parser.add_argument( '--nhead', type=int, default=8, help="Number of attention heads in transformer.") - parser.add_argument( '--attention-dim', type=int, default=512, help="Number of units in transformer attention layers.") - + parser.add_argument( + '--output-beam-size', + type=float, + default=8, + help='Output beam size. Used in k2.intersect_dense_pruned.'\ + 'Choose a large value (e.g., 20), for 1-best decoding '\ + 'and n-best rescoring. 
Choose a small value (e.g., 8) for ' \ + 'rescoring with the whole lattice') + parser.add_argument( + '--use-lm-rescoring', + type=str2bool, + default=True, + help='When enabled, it uses LM for rescoring') parser.add_argument( '--num-paths', type=int, @@ -138,17 +382,6 @@ def get_parser(): default=False, help='When enabled, train an identical model to the espnet SOTA released model' "url: https://zenodo.org/record/4604066#.YNAAHmgzZPY") - parser.add_argument( - '--epoch', - type=int, - default=29, - help="Decoding epoch.") - parser.add_argument( - '--avg', - type=int, - default=5, - help="Number of checkpionts to average. Automaticly select " - "consecutive checkpoints before checkpoint specified by'--epoch'. ") parser.add_argument( '--generate-release-model', @@ -161,17 +394,14 @@ def get_parser(): type=str2bool, default=False, help='When enabled, decode and evaluate with the released model') - parser.add_argument( - '--att-rate', - type=float, - default=0.7, - help="Attention loss rate.") parser.add_argument( - '--output-beam-size', + '--lr-factor', type=float, - default=8, - help='Output beam size. Used in k2.intersect_dense_pruned.') + default=10.0, + help='Learning rate factor for Noam optimizer.' + ) + return parser @@ -187,22 +417,36 @@ def main(): att_rate = args.att_rate model_type = args.model_type epoch = args.epoch - + num_paths = args.num_paths + use_lm_rescoring = args.use_lm_rescoring + use_whole_lattice = False + if use_lm_rescoring and num_paths < 1: + # It doesn't make sense to use n-best list for rescoring + # when n is less than 1 + use_whole_lattice = True + + output_beam_size = args.output_beam_size + + # load L, G, symbol_table + logging.debug("About to load phone and word symbols") + lang_dir = Path('./data/lang_bpe2/') + symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt') + phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt') + phone_ids = get_phone_symbols(phone_symbol_table) + + logging.debug("About to load model") # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N # device = torch.device('cuda', 1) device = torch.device('cuda') - lang_dir = Path('data/en_token_list/bpe_unigram5000/') - bpe_model_path = lang_dir / 'bpe.model' - tokens_file = lang_dir / 'tokens.txt' - numericalizer = Numericalizer.build_numericalizer(bpe_model_path, tokens_file) if att_rate != 0.0: num_decoder_layers = 6 else: num_decoder_layers = 0 - num_classes = len(numericalizer.tokens_list) + num_classes = len(phone_ids) + 1 + assert num_classes == 5000, print(num_classes) if model_type == "conformer": model = Conformer( num_features=80, @@ -220,7 +464,7 @@ def main(): else: raise NotImplementedError("Model of type " + str(model_type) + " is not verified") - exp_dir = Path(f'exp-bpe-{model_type}-{attention_dim}-{nhead}-noam/') + exp_dir = Path(f'exp-bpe-lrfactor{args.lr_factor}-{model_type}-{attention_dim}-{nhead}-noam/') if args.decode_with_released_model is True: released_model_path = exp_dir / f'model-epoch-{epoch}-avg-{avg}.pt' model.load_state_dict(torch.load(released_model_path)) @@ -251,31 +495,122 @@ def main(): logging.info("Loading pre-compiled ctc topo fst") d_ctc_topo = torch.load(ctc_path) ctc_topo = k2.Fsa.from_dict(d_ctc_topo) - ctc_topo = ctc_topo.to(device) - feature_dir = Path('exp/data') + if not os.path.exists(lang_dir / 'HLG.pt'): + logging.debug("Loading L_disambig.fst.txt") + with open(lang_dir / 'L_disambig.fst.txt') as f: + L = k2.Fsa.from_openfst(f.read(), acceptor=False) + logging.debug("Loading G.fst.txt") + with 
open(lang_dir / 'G.fst.txt') as f: + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + first_phone_disambig_id = find_first_disambig_symbol(phone_symbol_table) + first_word_disambig_id = find_first_disambig_symbol(symbol_table) + HLG = compile_HLG(L=L, + G=G, + H=ctc_topo, + labels_disambig_id_start=first_phone_disambig_id, + aux_labels_disambig_id_start=first_word_disambig_id) + torch.save(HLG.as_dict(), lang_dir / 'HLG.pt') + else: + logging.debug("Loading pre-compiled HLG") + d = torch.load(lang_dir / 'HLG.pt') + HLG = k2.Fsa.from_dict(d) + + if use_lm_rescoring: + if use_whole_lattice: + logging.info('Rescoring with the whole lattice') + else: + logging.info(f'Rescoring with n-best list, n is {num_paths}') + first_word_disambig_id = find_first_disambig_symbol(symbol_table) + if not os.path.exists(lang_dir / 'G_4_gram.pt'): + logging.debug('Loading G_4_gram.fst.txt') + with open(lang_dir / 'G_4_gram.fst.txt') as f: + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION(fangjun): The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + G = k2.create_fsa_vec([G]).to(device) + G = k2.arc_sort(G) + torch.save(G.as_dict(), lang_dir / 'G_4_gram.pt') + else: + logging.debug('Loading pre-compiled G_4_gram.pt') + d = torch.load(lang_dir / 'G_4_gram.pt') + G = k2.Fsa.from_dict(d).to(device) + + if use_whole_lattice: + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + else: + logging.debug('Decoding without LM rescoring') + G = None + if num_paths > 1: + logging.debug(f'Use n-best list decoding, n is {num_paths}') + else: + logging.debug('Use 1-best decoding') + + logging.debug("convert HLG to device") + HLG = HLG.to(device) + HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0) + HLG.requires_grad_(False) + + if not hasattr(HLG, 'lm_scores'): + HLG.lm_scores = HLG.scores.clone() librispeech = LibriSpeechAsrDataModule(args) test_sets = ['test-clean', 'test-other'] for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()): - results = decode(dataloader=test_dl, - model=model, - device=device, - ctc_topo=ctc_topo, - numericalizer=numericalizer, - num_paths=args.num_paths, - output_beam_size=args.output_beam_size) - - recog_path = exp_dir / f'recogs-{test_set}.txt' - store_transcripts(path=recog_path, texts=results) - logging.info(f'The transcripts are stored in {recog_path}') - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
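A side note on the G_4_gram preparation added above: arcs entering the LM back-off state carry the disambiguation symbol #0, so the patch maps every label at or above the first disambiguation ID to epsilon (0) before rescoring. A toy tensor version of that remapping, with made-up IDs standing in for G.labels and the value returned by find_first_disambig_symbol():

import torch

labels = torch.tensor([5, 17, 200003, 9, 200004])   # hypothetical word IDs; 200003 plays the role of '#0'
first_word_disambig_id = 200003
labels[labels >= first_word_disambig_id] = 0         # back-off arcs become epsilon
print(labels)                                         # tensor([ 5, 17,  0,  9,  0])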
- errs_filename = exp_dir / f'errs-{test_set}.txt' - with open(errs_filename, 'w') as f: - write_error_stats(f, test_set, results) - logging.info('Wrote detailed error stats to {}'.format(errs_filename)) - -if __name__ == "__main__": + logging.info(f'* DECODING: {test_set}') + + test_set_wers = dict() + results_dict = decode(dataloader=test_dl, + model=model, + HLG=HLG, + symbols=symbol_table, + num_paths=num_paths, + G=G, + use_whole_lattice=use_whole_lattice, + output_beam_size=output_beam_size) + + for key, results in results_dict.items(): + recog_path = exp_dir / f'recogs-{test_set}-{key}.txt' + store_transcripts(path=recog_path, texts=results) + logging.info(f'The transcripts are stored in {recog_path}') + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = exp_dir / f'errs-{test_set}-{key}.txt' + with open(errs_filename, 'w') as f: + wer = write_error_stats(f, f'{test_set}-{key}', results) + test_set_wers[key] = wer + + logging.info('Wrote detailed error stats to {}'.format(errs_filename)) + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = exp_dir / f'wer-summary-{test_set}.txt' + with open(errs_info, 'w') as f: + print('settings\tWER', file=f) + for key, val in test_set_wers: + print('{}\t{}'.format(key, val), file=f) + + s = '\nFor {}, WER of different settings are:\n'.format(test_set) + note = '\tbest for {}'.format(test_set) + for key, val in test_set_wers: + s += '{}\t{}{}\n'.format(key, val, note) + note='' + logging.info(s) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == '__main__': main() diff --git a/egs/librispeech/asr/simple_v1/bpe_run.sh b/egs/librispeech/asr/simple_v1/bpe_run.sh index 82617946..3f982528 100644 --- a/egs/librispeech/asr/simple_v1/bpe_run.sh +++ b/egs/librispeech/asr/simple_v1/bpe_run.sh @@ -12,8 +12,12 @@ if [ $download_model -eq 1 ]; then if [ -d snowfall_bpe_model ]; then echo "Model seems already been dowloaded" else + if ! type git-lfs >/dev/null 2>&1; then + echo 'Please Install git-lfs to download trained models'; + exit 0 + fi git clone https://huggingface.co/GuoLiyong/snowfall_bpe_model - for sub_dir in data exp-bpe-conformer-512-8-noam; do + for sub_dir in data exp-bpe-lrfactor10.0-conformer-512-8-noam; do ln -sf snowfall_bpe_model/$sub_dir ./ done fi @@ -25,9 +29,80 @@ if [ ! -f exp/data/cuts_test-clean.json.gz ]; then fi if [ $stage -le 1 ]; then - export CUDA_VISIBLE_DEVICES=3 + local/download_lm.sh "openslr.org/resources/11" data/local/lm +fi + +if [ $stage -le 2 ]; then + dir=data/lang_bpe2 + mkdir -p $dir + token_file=./data/en_token_list/bpe_unigram5000/tokens.txt + model_file=./data/en_token_list/bpe_unigram5000/bpe.model + cp $token_file $dir/tokens.txt + ln -fv $dir/tokens.txt $dir/phones.txt + echo " 0" > $dir/words.txt + echo " 1" >> $dir/words.txt + cat data/local/lm/librispeech-vocab.txt | sort | uniq | + awk '{print $1 " " NR+1}' >> $dir/words.txt + + if [ ! -f $dir/lexicon.txt ]; then + python3 ./generate_bpe_lexicon.py \ + --model-file $model_file \ + --words-file $dir/words.txt > $dir/lexicon.txt + fi + + if [ ! -f $dir/lexiconp.txt ]; then + echo "**Creating $dir/lexiconp.txt from $dir/lexicon.txt" + perl -ape 's/(\S+\s+)(.+)/${1}1.0\t$2/;' < $dir/lexicon.txt > $dir/lexiconp.txt || exit 1 + fi + + if ! 
grep "#0" $dir/words.txt > /dev/null 2>&1; then + max_word_id=$(tail -1 $dir/words.txt | awk '{print $2}') + echo "#0 $((max_word_id+1))" >> $dir/words.txt + fi + + ndisambig=$(local/add_lex_disambig.pl --pron-probs $dir/lexiconp.txt $dir/lexiconp_disambig.txt) + if ! grep "#0" $dir/phones.txt > /dev/null 2>&1 ; then + max_phone_id=$(tail -1 $dir/phones.txt | awk '{print $2}') + for i in $(seq 0 $ndisambig); do + echo "#$i $((i+max_phone_id+1))" + done >> $dir/phones.txt + fi + + if [ ! -f $dir/L_disambig.fst.txt ]; then + wdisambig_phone=$(echo "#0" | local/sym2int.pl $dir/phones.txt) + wdisambig_word=$(echo "#0" | local/sym2int.pl $dir/words.txt) + + local/make_lexicon_fst.py \ + $dir/lexiconp_disambig.txt | \ + local/sym2int.pl --map-oov 1 -f 3 $dir/phones.txt | \ + local/sym2int.pl -f 4 $dir/words.txt | \ + local/fstaddselfloops.pl $wdisambig_phone $wdisambig_word > $dir/L_disambig.fst.txt || exit 1 + fi + + if [ ! -f $dir/G.fst.txt ]; then + python3 -m kaldilm \ + --read-symbol-table="$dir/words.txt" \ + --disambig-symbol='#0' \ + --max-order=3 \ + data/local/lm/lm_tgmed.arpa > $dir/G.fst.txt + else + echo "Skip generating $dir/G.fst.txt" + fi + if [ ! -f $dir/G_4_gram.fst.txt ]; then + python3 -m kaldilm \ + --read-symbol-table="$dir/words.txt" \ + --disambig-symbol='#0' \ + --max-order=4 \ + data/local/lm/lm_fglarge.arpa >$dir/G_4_gram.fst.txt + else + echo "Skip generating data/lang_nosp/G_4_gram.fst.txt" + fi +fi + +if [ $stage -le 3 ]; then + export CUDA_VISIBLE_DEVICES=2 python bpe_ctc_att_conformer_decode.py \ - --max-duration=10 \ + --max-duration=5 \ --generate-release-model=False \ --decode_with_released_model=True fi diff --git a/egs/librispeech/asr/simple_v1/generate_bpe_lexicon.py b/egs/librispeech/asr/simple_v1/generate_bpe_lexicon.py new file mode 100644 index 00000000..67056e11 --- /dev/null +++ b/egs/librispeech/asr/simple_v1/generate_bpe_lexicon.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +from pathlib import Path +from typing import List + +import argparse +import sentencepiece as spm + + +def read_words(words_txt: str, excluded=['', '']) -> List[str]: + '''Read words_txt and return a list of words. + The file words_txt has the following format: + + That is, every line has two fields. This function + extracts the first field. + Args: + words_txt: + Filename of words.txt. + excluded: + words in this list are not returned. + Returns: + Return a list of words. 
+ ''' + ans = [] + with open(words_txt, 'r', encoding='latin-1') as f: + for line in f: + word, _ = line.strip().split() + if word not in excluded: + ans.append(word) + return ans + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument('--model-file', + type=str, + help='Pre-trained BPE model file') + + parser.add_argument('--words-file', type=str, help='Path to words.txt') + + args = parser.parse_args() + model_file = args.model_file + words_txt = args.words_file + assert Path(model_file).is_file(), f'{model_file} does not exist' + assert Path(words_txt).is_file(), f'{words_txt} does not exist' + + words = read_words(words_txt) + + sp = spm.SentencePieceProcessor() + sp.load(model_file) + + for word in words: + pieces = sp.EncodeAsPieces(word.upper()) + print(word, ' '.join(pieces)) + + print('', '') + + +if __name__ == '__main__': + main() From d90250b3859c13cf6c99e7bbf56c36ea5f6a4d8a Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Tue, 6 Jul 2021 10:37:04 +0800 Subject: [PATCH 2/9] transformer decoder n-best rescore with batch_size=1 --- .../simple_v1/bpe_ctc_att_conformer_decode.py | 195 ++++++++---------- egs/librispeech/asr/simple_v1/bpe_run.sh | 20 +- snowfall/decoding/lm_rescore.py | 8 +- snowfall/models/transformer.py | 58 +++++- 4 files changed, 159 insertions(+), 122 deletions(-) diff --git a/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py b/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py index 34a613f0..e8a7f3ab 100755 --- a/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py +++ b/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py @@ -31,6 +31,7 @@ from snowfall.decoding.graph import compile_HLG from snowfall.decoding.lm_rescore import rescore_with_n_best_list from snowfall.decoding.lm_rescore import rescore_with_whole_lattice +from snowfall.decoding.lm_rescore import compute_am_scores, _intersect_device from snowfall.models import AcousticModel from snowfall.models.conformer import Conformer from snowfall.text.numericalizer import Numericalizer @@ -38,27 +39,29 @@ from snowfall.training.mmi_graph import get_phone_symbols -def nbest_decoding(lats: k2.Fsa, num_paths: int): +def nbest_decoding(model, encoder_memory, memory_mask, lats: k2.Fsa, num_paths: int): ''' - (Ideas of this function are from Dan) - - It implements something like CTC prefix beam search using n-best lists - - The basic idea is to first extra n-best paths from the given lattice, - build a word seqs from these paths, and compute the total scores - of these sequences in the log-semiring. The one with the max score - is used as the decoding output. + N-best rescore with transformer-decoder model. + The basic idea is to first extra n-best paths from the given lattice. + Then extract word_seqs and token_seqs for each path. + Compute the negative log-likehood for each token_seq as 'language model score', called decoder_scores. + Compute am score for each token_seq. + Total scores is a weight sum of am_score and decoder_scores. + The one with the max total score is used as the decoding output. ''' + # lats has token IDs as labels + # and word IDs as aux_labels. # First, extract `num_paths` paths for each sequence. # paths is a k2.RaggedInt with axes [seq][path][arc_pos] paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True) - # word_seqs is a k2.RaggedInt sharing the same shape as `paths` + # token_seqs/word_seqs is a k2.RaggedInt sharing the same shape as `paths` # but it contains word IDs. 
Note that it also contains 0s and -1s. # The last entry in each sublist is -1. + token_seqs = k2.index(lats.labels.contiguous(), paths) + word_seqs = k2.index(lats.aux_labels.contiguous(), paths) - word_seqs = k2.index(lats.aux_labels, paths) # Note: the above operation supports also the case when # lats.aux_labels is a ragged tensor. In that case, # `remove_axis=True` is used inside the pybind11 binding code, @@ -66,6 +69,7 @@ def nbest_decoding(lats: k2.Fsa, num_paths: int): # The 3 axes are [seq][path][word] # Remove epsilons and -1 from word_seqs + token_seqs = k2.ragged.remove_values_leq(token_seqs, 0) word_seqs = k2.ragged.remove_values_leq(word_seqs, 0) # Remove repeated sequences to avoid redundant computation later. @@ -74,11 +78,11 @@ def nbest_decoding(lats: k2.Fsa, num_paths: int): # `new2old` is a 1-D torch.Tensor mapping from the output path index # to the input path index. # new2old.numel() == unique_word_seqs.num_elements() - unique_word_seqs, _, new2old = k2.ragged.unique_sequences( - word_seqs, need_num_repeats=False, need_new2old_indexes=True) - # Note: unique_word_seqs still has the same axes as word_seqs + unique_token_seqs, _, new2old = k2.ragged.unique_sequences( + token_seqs, need_num_repeats=False, need_new2old_indexes=True) + # Note: unique_token_seqs still has the same axes as token_seqs - seq_to_path_shape = k2.ragged.get_layer(unique_word_seqs.shape(), 0) + seq_to_path_shape = k2.ragged.get_layer(unique_token_seqs.shape(), 0) # path_to_seq_map is a 1-D torch.Tensor. # path_to_seq_map[i] is the seq to which the i-th path @@ -87,66 +91,54 @@ def nbest_decoding(lats: k2.Fsa, num_paths: int): # Remove the seq axis. # Now unique_word_seqs has only two axes [path][word] - unique_word_seqs = k2.ragged.remove_axis(unique_word_seqs, 0) + unique_token_seqs = k2.ragged.remove_axis(unique_token_seqs, 0) - # word_fsas is an FsaVec with axes [path][state][arc] - word_fsas = k2.linear_fsa(unique_word_seqs) + # token_fsas is an FsaVec with axes [path][state][arc] + token_fsas = k2.linear_fsa(unique_token_seqs) - word_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsas) + token_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(token_fsas) - # lats has phone IDs as labels and word IDs as aux_labels. - # inv_lats has word IDs as labels and phone IDs as aux_labels + # lats has token IDs as labels and word IDs as aux_labels. 
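A toy view of the score bookkeeping used just below: the per-path totals taken from the rescored lattice mix acoustic and 4-gram LM contributions, so the LM part is recovered by subtracting the acoustic scores obtained from compute_am_scores(). All numbers here are made up; the real code works on per-path totals from get_tot_scores().

import torch

tot_scores = torch.tensor([-14.0, -13.5, -16.0])   # per-path totals from the rescored lattice
am_scores = torch.tensor([-10.0, -11.0, -9.5])     # per-path acoustic scores
fgram_lm_scores = tot_scores - am_scores
print(fgram_lm_scores)                              # tensor([-4.0000, -2.5000, -6.5000])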
+ # inv_lats has word IDs as labels and token IDs as aux_labels + # Do k2.invert to make it compatible to function compute_am_scores inv_lats = k2.invert(lats) inv_lats = k2.arc_sort(inv_lats) # no-op if inv_lats is already arc-sorted + am_scores = compute_am_scores(inv_lats, token_fsas_with_epsilon_loops, path_to_seq_map) + + lats = k2.arc_sort(lats) + fgram_lm_lats = _intersect_device(lats, token_fsas_with_epsilon_loops, path_to_seq_map, sorted_match_a=True) + fgram_lm_lats = k2.top_sort(k2.connect(fgram_lm_lats.to('cpu')).to(lats.device)) + # am_scores is computed with log_semiring=True + # set log_semiring=True here to make fgram_lm_scores comparable to am_scores + fgram_tot_scores = fgram_lm_lats.get_tot_scores(use_double_scores=True, log_semiring=True) + fgram_lm_scores = fgram_tot_scores - am_scores + + + # now compute lm scores from transformer decoder + token_ids = k2.ragged.to_list(unique_token_seqs) + num_seqs = len(token_ids) + time_steps = encoder_memory.shape[0] + feature_dim = encoder_memory.shape[2] + encoder_memory = encoder_memory.expand(time_steps, num_seqs, feature_dim) + memory_mask = memory_mask.expand(num_seqs, time_steps) + + # nll: negative log-likelihood + nll = model.decoder_nll(encoder_memory, memory_mask, token_ids=token_ids) + assert nll.shape[0] == num_seqs + decoder_scores = - nll.sum(dim=1) + tot_scores = am_scores + fgram_lm_scores + decoder_scores + best_seq_idx = new2old[torch.argmax(tot_scores)] + best_word_seq = [k2.ragged.to_list(word_seqs)[0][best_seq_idx]] + + return best_word_seq - path_lats = k2.intersect_device(inv_lats, - word_fsas_with_epsilon_loops, - b_to_a_map=path_to_seq_map, - sorted_match_a=True) - # path_lats has word IDs as labels and phone IDs as aux_labels - - path_lats = k2.top_sort(k2.connect(path_lats.to('cpu')).to(lats.device)) - - tot_scores = path_lats.get_tot_scores(True, True) - # RaggedFloat currently supports float32 only. - # We may bind Ragged as RaggedDouble if needed. - ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape, - tot_scores.to(torch.float32)) - - argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores) - - # Since we invoked `k2.ragged.unique_sequences`, which reorders - # the index from `paths`, we use `new2old` - # here to convert argmax_indexes to the indexes into `paths`. - # - # Use k2.index here since argmax_indexes' dtype is torch.int32 - best_path_indexes = k2.index(new2old, argmax_indexes) - - paths_2axes = k2.ragged.remove_axis(paths, 0) - - # best_paths is a k2.RaggedInt with 2 axes [path][arc_pos] - best_paths = k2.index(paths_2axes, best_path_indexes) - - # labels is a k2.RaggedInt with 2 axes [path][phone_id] - # Note that it contains -1s. - labels = k2.index(lats.labels.contiguous(), best_paths) - - labels = k2.ragged.remove_values_eq(labels, -1) - - # lats.aux_labels is a k2.RaggedInt tensor with 2 axes, so - # aux_labels is also a k2.RaggedInt with 2 axes - aux_labels = k2.index(lats.aux_labels, best_paths.values()) - - best_path_fsas = k2.linear_fsa(labels) - best_path_fsas.aux_labels = aux_labels - - return best_path_fsas def decode_one_batch(batch: Dict[str, Any], model: AcousticModel, HLG: k2.Fsa, output_beam_size: float, num_paths: int, use_whole_lattice: bool, + nbest_rescore_with_decoder: bool = True, G: Optional[k2.Fsa] = None)->Dict[str, List[List[int]]]: ''' Decode one batch and return the result in a dict. 
The dict has the @@ -178,7 +170,7 @@ def decode_one_batch(batch: Dict[str, Any], If False and if `G` is not None, then `num_paths` must be positive and it will use n-best list for LM rescoring. num_paths: - It specifies the size of `n` in n-best list decoding. + It specifies the size of `n` in n-best list decoding with transforer decoder model. G: The LM. If it is None, no rescoring is used. Otherwise, LM rescoring is used. @@ -194,13 +186,15 @@ def decode_one_batch(batch: Dict[str, Any], feature = batch['inputs'] assert feature.ndim == 3 feature = feature.to(device) + batch_size = feature.shape[0] + assert batch_size == 1, 'Currently only surrort batch_size=1' # at entry, feature is [N, T, C] feature = feature.permute(0, 2, 1) # now feature is [N, C, T] supervisions = batch['supervisions'] - nnet_output, _, _ = model(feature, supervisions) + nnet_output, encoder_memory, memory_mask = model(feature, supervisions) # nnet_output is [N, C, T] nnet_output = nnet_output.permute(0, 2, 1) @@ -220,36 +214,27 @@ def decode_one_batch(batch: Dict[str, Any], lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0, output_beam_size, 30, 10000) - if G is None: - if num_paths > 1: - best_paths = nbest_decoding(lattices, num_paths) - key=f'no_rescore-{num_paths}' - else: - key = 'no_rescore' - best_paths = k2.shortest_path(lattices, use_double_scores=True) - hyps = get_texts(best_paths, indices) - return {key: hyps} + # TODO(Guo Liyong): figure out a way to combine lm_scale_list with transformer decoder n-best rescore + # lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + # lm_scale_list += [0.45, 0.55, 0.65] + # lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + # lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] - lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] - lm_scale_list += [0.45, 0.55, 0.65] - lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] - lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + lm_scale_list = [0.6] # lowest wer = 2.92 without transformer n-best rescore if use_whole_lattice: best_paths_dict = rescore_with_whole_lattice(lattices, G, - lm_scale_list) - else: - best_paths_dict = rescore_with_n_best_list(lattices, G, num_paths, - lm_scale_list) - # best_paths_dict is a dict - # - key: lm_scale_xxx, where xxx is the value of lm_scale. An example - # key is lm_scale_1.2 - # - value: it is the best path obtained using the corresponding lm scale - # from the dict key. - + lm_scale_list, + need_rescored_lats=True) ans = dict() - for lm_scale_str, best_paths in best_paths_dict.items(): - hyps = get_texts(best_paths, indices) + for lm_scale_str, (best_paths, ngram_rescored_lattices) in best_paths_dict.items(): + assert best_paths.shape[0] == 1, 'Figuring out a way to do batch decoding' + if nbest_rescore_with_decoder: + best_word_seq = nbest_decoding(model, encoder_memory, memory_mask, ngram_rescored_lattices, num_paths) + hyps = best_word_seq + + else: + hyps = get_texts(best_paths, indices) ans[lm_scale_str] = hyps return ans @@ -355,13 +340,10 @@ def get_parser(): default=True, help='When enabled, it uses LM for rescoring') parser.add_argument( - '--num-paths', + '--num-paths-for-decoder-rescore', type=int, - default=-1, - help='Number of paths for rescoring using n-best list.' 
\ - 'If it is negative, then rescore with the whole lattice.'\ - 'CAUTION: You have to reduce max_duration in case of CUDA OOM' - ) + default=500, + help='Number of paths for rescoring using n-best list with transformer decoder model.') parser.add_argument( '--is-espnet-structure', @@ -417,13 +399,8 @@ def main(): att_rate = args.att_rate model_type = args.model_type epoch = args.epoch - num_paths = args.num_paths use_lm_rescoring = args.use_lm_rescoring - use_whole_lattice = False - if use_lm_rescoring and num_paths < 1: - # It doesn't make sense to use n-best list for rescoring - # when n is less than 1 - use_whole_lattice = True + use_whole_lattice = True output_beam_size = args.output_beam_size @@ -519,8 +496,6 @@ def main(): if use_lm_rescoring: if use_whole_lattice: logging.info('Rescoring with the whole lattice') - else: - logging.info(f'Rescoring with n-best list, n is {num_paths}') first_word_disambig_id = find_first_disambig_symbol(symbol_table) if not os.path.exists(lang_dir / 'G_4_gram.pt'): logging.debug('Loading G_4_gram.fst.txt') @@ -551,12 +526,8 @@ def main(): # LM rescoring. G.lm_scores = G.scores.clone() else: - logging.debug('Decoding without LM rescoring') - G = None - if num_paths > 1: - logging.debug(f'Use n-best list decoding, n is {num_paths}') - else: - logging.debug('Use 1-best decoding') + logging.debug('Currently 4-gram lattice rescore is required.') + sys.exit() logging.debug("convert HLG to device") HLG = HLG.to(device) @@ -569,6 +540,8 @@ def main(): librispeech = LibriSpeechAsrDataModule(args) test_sets = ['test-clean', 'test-other'] for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()): + # if 'test-clean' == test_set: + # continue logging.info(f'* DECODING: {test_set}') test_set_wers = dict() @@ -576,7 +549,7 @@ def main(): model=model, HLG=HLG, symbols=symbol_table, - num_paths=num_paths, + num_paths=args.num_paths_for_decoder_rescore, G=G, use_whole_lattice=use_whole_lattice, output_beam_size=output_beam_size) diff --git a/egs/librispeech/asr/simple_v1/bpe_run.sh b/egs/librispeech/asr/simple_v1/bpe_run.sh index 3f982528..b1c850ed 100644 --- a/egs/librispeech/asr/simple_v1/bpe_run.sh +++ b/egs/librispeech/asr/simple_v1/bpe_run.sh @@ -37,12 +37,14 @@ if [ $stage -le 2 ]; then mkdir -p $dir token_file=./data/en_token_list/bpe_unigram5000/tokens.txt model_file=./data/en_token_list/bpe_unigram5000/bpe.model - cp $token_file $dir/tokens.txt - ln -fv $dir/tokens.txt $dir/phones.txt - echo " 0" > $dir/words.txt - echo " 1" >> $dir/words.txt - cat data/local/lm/librispeech-vocab.txt | sort | uniq | - awk '{print $1 " " NR+1}' >> $dir/words.txt + if [ ! -f $dir/tokens.txt ]; then + cp $token_file $dir/tokens.txt + ln -fv $dir/tokens.txt $dir/phones.txt + echo " 0" > $dir/words.txt + echo " 1" >> $dir/words.txt + cat data/local/lm/librispeech-vocab.txt | sort | uniq | + awk '{print $1 " " NR+1}' >> $dir/words.txt + fi if [ ! 
-f $dir/lexicon.txt ]; then python3 ./generate_bpe_lexicon.py \ @@ -101,8 +103,10 @@ fi if [ $stage -le 3 ]; then export CUDA_VISIBLE_DEVICES=2 + # Set max-duration=1 because rescore with decoder only support batch_size=1 python bpe_ctc_att_conformer_decode.py \ - --max-duration=5 \ + --max-duration=1 \ --generate-release-model=False \ - --decode_with_released_model=True + --decode_with_released_model=True \ + --num-paths-for-decoder-rescore=500 fi diff --git a/snowfall/decoding/lm_rescore.py b/snowfall/decoding/lm_rescore.py index 1edef233..cf7ed20a 100644 --- a/snowfall/decoding/lm_rescore.py +++ b/snowfall/decoding/lm_rescore.py @@ -244,7 +244,8 @@ def rescore_with_n_best_list(lats: k2.Fsa, G: k2.Fsa, num_paths: int, @torch.no_grad() def rescore_with_whole_lattice(lats: k2.Fsa, G_with_epsilon_loops: k2.Fsa, - lm_scale_list: List[float] + lm_scale_list: List[float], + need_rescored_lats: bool = False, ) -> Dict[str, k2.Fsa]: '''Use whole lattice to rescore. @@ -319,5 +320,8 @@ def rescore_with_whole_lattice(lats: k2.Fsa, G_with_epsilon_loops: k2.Fsa, best_paths = k2.shortest_path(inv_lats, use_double_scores=True) key = f'lm_scale_{lm_scale}' - ans[key] = best_paths + if need_rescored_lats: + ans[key] = (best_paths, inv_lats) + else: + ans[key] = best_paths return ans diff --git a/snowfall/models/transformer.py b/snowfall/models/transformer.py index 6f50da95..c440f4a6 100644 --- a/snowfall/models/transformer.py +++ b/snowfall/models/transformer.py @@ -165,7 +165,7 @@ def decoder_forward(self, x: Tensor, encoder_mask: Tensor, supervision: Supervis ys_in = [torch.cat([_sos, torch.tensor(y)], dim=0) for y in token_ids] ys_out = [torch.cat([torch.tensor(y), _eos], dim=0) for y in token_ids] ys_in_pad = pad_list(ys_in, eos_id) - ys_out_pad = pad_list(ys_in, -1) + ys_out_pad = pad_list(ys_out, -1) else: raise ValueError("Invalid input for decoder self attetion") @@ -193,6 +193,62 @@ def decoder_forward(self, x: Tensor, encoder_mask: Tensor, supervision: Supervis return decoder_loss + def decoder_nll(self, x: Tensor, encoder_mask: Tensor, token_ids: List[int] = None) -> Tensor: + """ + Args: + x: encoder-output, Tensor of dimension (input_length, batch_size, d_model). + encoder_mask: Mask tensor of dimension (batch_size, input_length) + token_ids: n-best list extracted from lattice before rescore + + Returns: + Tensor: negative log-likelihood. + """ + # The common part between this fuction and decoder_forward could be + # extracted as a seperated function. 
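A small self-contained illustration of the reduction scheme decoder_nll() relies on further down: per-token cross entropy with reduction='none', where padded positions marked by -1 are ignored and contribute zero, so a per-sequence decoder score is just the negated row sum. Shapes and values below are made up.

import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)                 # (num_positions, num_classes)
targets = torch.tensor([2, 0, 1, -1])      # -1 marks a padded position
nll = F.cross_entropy(logits, targets, ignore_index=-1, reduction='none')
print(nll)                                  # last entry is 0.0 (ignored padding)
print(-nll.sum())                           # sequence-level decoder score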
+ if token_ids is not None: + # speical token ids: + # 0 + # 1 + # self.decoder_num_class - 1 + sos_id = self.decoder_num_class - 1 + eos_id = self.decoder_num_class - 1 + _sos = torch.tensor([sos_id]) + _eos = torch.tensor([eos_id]) + ys_in = [torch.cat([_sos, torch.tensor(y)], dim=0) for y in token_ids] + ys_out = [torch.cat([torch.tensor(y), _eos], dim=0) for y in token_ids] + ys_in_pad = pad_list(ys_in, eos_id) + else: + raise ValueError("Invalid input for decoder self attetion") + + + ys_in_pad = ys_in_pad.to(x.device, dtype=torch.int64) + ys_out_pad = ys_out_pad.to(x.device, dtype=torch.int64) + + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(x.device) + + tgt_key_padding_mask = decoder_padding_mask(ys_in_pad) + + tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) + tgt = self.decoder_pos(tgt) + tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + pred_pad = self.decoder(tgt=tgt, + memory=x, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=encoder_mask) # (T, B, F) + pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F) + pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F) + # nll: negative log-likelihood + nll = torch.nn.functional.cross_entropy( + pred_pad.view(-1, self.decoder_num_class), + ys_out_pad.view(-1), + ignore_index=-1, + reduction='none') + + nll = nll.view(pred_pad.shape[0], -1) + + return nll + class TransformerEncoderLayer(nn.Module): """ From 9ea46edf4e77f8b74e78af68c6e60070ca94dfb2 Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Wed, 7 Jul 2021 11:10:50 +0800 Subject: [PATCH 3/9] fix typo and use log_semiring=False --- .../asr/simple_v1/bpe_ctc_att_conformer_decode.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py b/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py index e8a7f3ab..19c0d68f 100755 --- a/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py +++ b/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py @@ -107,10 +107,9 @@ def nbest_decoding(model, encoder_memory, memory_mask, lats: k2.Fsa, num_paths: lats = k2.arc_sort(lats) fgram_lm_lats = _intersect_device(lats, token_fsas_with_epsilon_loops, path_to_seq_map, sorted_match_a=True) - fgram_lm_lats = k2.top_sort(k2.connect(fgram_lm_lats.to('cpu')).to(lats.device)) - # am_scores is computed with log_semiring=True - # set log_semiring=True here to make fgram_lm_scores comparable to am_scores - fgram_tot_scores = fgram_lm_lats.get_tot_scores(use_double_scores=True, log_semiring=True) + fgram_lm_lats = k2.top_sort(k2.connect(fgram_lm_lats)) + # log_semiring=False is a little better than log_semiring=True. + fgram_tot_scores = fgram_lm_lats.get_tot_scores(use_double_scores=True, log_semiring=False) fgram_lm_scores = fgram_tot_scores - am_scores @@ -309,7 +308,7 @@ def get_parser(): '--avg', type=int, default=10, - help="Number of checkpionts to average. Automaticly select " + help="Number of checkpionts to average. Automatically select " "consecutive checkpoints before checkpoint specified by'--epoch'. 
") parser.add_argument( '--att-rate', @@ -540,8 +539,6 @@ def main(): librispeech = LibriSpeechAsrDataModule(args) test_sets = ['test-clean', 'test-other'] for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()): - # if 'test-clean' == test_set: - # continue logging.info(f'* DECODING: {test_set}') test_set_wers = dict() From 6d1e935ff3148a81af3586c4827b3c899c7fffad Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Wed, 7 Jul 2021 16:13:40 +0800 Subject: [PATCH 4/9] use word_seqs to rescore --- .../simple_v1/bpe_ctc_att_conformer_decode.py | 33 +++++++++++-------- snowfall/decoding/lm_rescore.py | 3 +- snowfall/models/transformer.py | 1 + 3 files changed, 21 insertions(+), 16 deletions(-) diff --git a/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py b/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py index 19c0d68f..48a4d429 100755 --- a/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py +++ b/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py @@ -78,11 +78,11 @@ def nbest_decoding(model, encoder_memory, memory_mask, lats: k2.Fsa, num_paths: # `new2old` is a 1-D torch.Tensor mapping from the output path index # to the input path index. # new2old.numel() == unique_word_seqs.num_elements() - unique_token_seqs, _, new2old = k2.ragged.unique_sequences( - token_seqs, need_num_repeats=False, need_new2old_indexes=True) - # Note: unique_token_seqs still has the same axes as token_seqs + unique_word_seqs, _, new2old = k2.ragged.unique_sequences( + word_seqs, need_num_repeats=False, need_new2old_indexes=True) + # Note: unique_word_seqs still has the same axes as word_seqs - seq_to_path_shape = k2.ragged.get_layer(unique_token_seqs.shape(), 0) + seq_to_path_shape = k2.ragged.get_layer(unique_word_seqs.shape(), 0) # path_to_seq_map is a 1-D torch.Tensor. # path_to_seq_map[i] is the seq to which the i-th path @@ -91,22 +91,24 @@ def nbest_decoding(model, encoder_memory, memory_mask, lats: k2.Fsa, num_paths: # Remove the seq axis. # Now unique_word_seqs has only two axes [path][word] - unique_token_seqs = k2.ragged.remove_axis(unique_token_seqs, 0) + unique_word_seqs = k2.ragged.remove_axis(unique_word_seqs, 0) - # token_fsas is an FsaVec with axes [path][state][arc] - token_fsas = k2.linear_fsa(unique_token_seqs) + # word_fsas is an FsaVec with axes [path][state][arc] + word_fsas = k2.linear_fsa(unique_word_seqs) - token_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(token_fsas) + word_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsas) + + am_scores = compute_am_scores(lats, word_fsas_with_epsilon_loops, path_to_seq_map) + import pdb; pdb.set_trace() # lats has token IDs as labels and word IDs as aux_labels. 
# inv_lats has word IDs as labels and token IDs as aux_labels # Do k2.invert to make it compatible to function compute_am_scores - inv_lats = k2.invert(lats) - inv_lats = k2.arc_sort(inv_lats) # no-op if inv_lats is already arc-sorted - am_scores = compute_am_scores(inv_lats, token_fsas_with_epsilon_loops, path_to_seq_map) + # inv_lats = k2.invert(lats) + inv_lats = k2.arc_sort(k2.invert(lats)) # no-op if inv_lats is already arc-sorted - lats = k2.arc_sort(lats) - fgram_lm_lats = _intersect_device(lats, token_fsas_with_epsilon_loops, path_to_seq_map, sorted_match_a=True) + # lats = k2.arc_sort(lats) + fgram_lm_lats = _intersect_device(inv_lats, word_fsas_with_epsilon_loops, path_to_seq_map, sorted_match_a=True) fgram_lm_lats = k2.top_sort(k2.connect(fgram_lm_lats)) # log_semiring=False is a little better than log_semiring=True. fgram_tot_scores = fgram_lm_lats.get_tot_scores(use_double_scores=True, log_semiring=False) @@ -114,7 +116,10 @@ def nbest_decoding(model, encoder_memory, memory_mask, lats: k2.Fsa, num_paths: # now compute lm scores from transformer decoder - token_ids = k2.ragged.to_list(unique_token_seqs) + # Now token_seqs has only two axes [path][word] + token_seqs = k2.ragged.remove_axis(token_seqs, 0) + token_ids, _ = k2.ragged.index(token_seqs, new2old, axis=0) + token_ids = k2.ragged.to_list(token_ids) num_seqs = len(token_ids) time_steps = encoder_memory.shape[0] feature_dim = encoder_memory.shape[2] diff --git a/snowfall/decoding/lm_rescore.py b/snowfall/decoding/lm_rescore.py index cf7ed20a..6d153c53 100644 --- a/snowfall/decoding/lm_rescore.py +++ b/snowfall/decoding/lm_rescore.py @@ -93,8 +93,7 @@ def compute_am_scores(lats: k2.Fsa, word_fsas_with_epsilon_loops: k2.Fsa, b_to_a_map=path_to_seq_map, sorted_match_a=True) - # NOTE: `k2.connect` supports only CPU at present - am_path_lats = k2.top_sort(k2.connect(am_path_lats.to('cpu')).to(device)) + am_path_lats = k2.top_sort(k2.connect(am_path_lats)) # The `scores` of every arc consists of `am_scores` and `lm_scores` am_path_lats.scores = am_path_lats.scores - am_path_lats.lm_scores diff --git a/snowfall/models/transformer.py b/snowfall/models/transformer.py index c440f4a6..a15dd70d 100644 --- a/snowfall/models/transformer.py +++ b/snowfall/models/transformer.py @@ -217,6 +217,7 @@ def decoder_nll(self, x: Tensor, encoder_mask: Tensor, token_ids: List[int] = No ys_in = [torch.cat([_sos, torch.tensor(y)], dim=0) for y in token_ids] ys_out = [torch.cat([torch.tensor(y), _eos], dim=0) for y in token_ids] ys_in_pad = pad_list(ys_in, eos_id) + ys_out_pad = pad_list(ys_out, -1) else: raise ValueError("Invalid input for decoder self attetion") From 5c979cce1b6a9c9bf72ec484746143b321ae73a7 Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Thu, 8 Jul 2021 15:37:12 +0800 Subject: [PATCH 5/9] compare nbest_rescore result between unique_word_seqs and unique_token_seqs --- .../simple_v1/bpe_ctc_att_conformer_decode.py | 155 ++++++++++-------- snowfall/decoding/lm_rescore.py | 10 +- 2 files changed, 89 insertions(+), 76 deletions(-) diff --git a/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py b/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py index 48a4d429..9e9d710c 100755 --- a/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py +++ b/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py @@ -39,87 +39,103 @@ from snowfall.training.mmi_graph import get_phone_symbols -def nbest_decoding(model, encoder_memory, memory_mask, lats: k2.Fsa, num_paths: int): - ''' - N-best rescore with 
transformer-decoder model. - The basic idea is to first extra n-best paths from the given lattice. - Then extract word_seqs and token_seqs for each path. - Compute the negative log-likehood for each token_seq as 'language model score', called decoder_scores. - Compute am score for each token_seq. - Total scores is a weight sum of am_score and decoder_scores. - The one with the max total score is used as the decoding output. - ''' - +def extract_nbest_list(lats: k2.Fsa, num_paths: int): # lats has token IDs as labels # and word IDs as aux_labels. # First, extract `num_paths` paths for each sequence. # paths is a k2.RaggedInt with axes [seq][path][arc_pos] paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True) - - # token_seqs/word_seqs is a k2.RaggedInt sharing the same shape as `paths` - # but it contains word IDs. Note that it also contains 0s and -1s. + # Both token_seqs and word_seqs are k2.RaggedInt sharing the same shape as `paths` + # Note that they also contain 0s and -1s. # The last entry in each sublist is -1. token_seqs = k2.index(lats.labels.contiguous(), paths) word_seqs = k2.index(lats.aux_labels.contiguous(), paths) - # Note: the above operation supports also the case when - # lats.aux_labels is a ragged tensor. In that case, - # `remove_axis=True` is used inside the pybind11 binding code, - # so the resulting `word_seqs` still has 3 axes, like `paths`. - # The 3 axes are [seq][path][word] - - # Remove epsilons and -1 from word_seqs token_seqs = k2.ragged.remove_values_leq(token_seqs, 0) word_seqs = k2.ragged.remove_values_leq(word_seqs, 0) + return token_seqs, word_seqs - # Remove repeated sequences to avoid redundant computation later. - # - # Since k2.ragged.unique_sequences will reorder paths within a seq, - # `new2old` is a 1-D torch.Tensor mapping from the output path index - # to the input path index. - # new2old.numel() == unique_word_seqs.num_elements() +def compute_am_flm_scrores_1(lats, word_seqs, token_seqs): + ''' + Compute am scores with word_seqs + wer is worse than compute_am_flm_scores_2 + ''' + # lats has token IDs as labels and word IDs as aux_labels. unique_word_seqs, _, new2old = k2.ragged.unique_sequences( word_seqs, need_num_repeats=False, need_new2old_indexes=True) - # Note: unique_word_seqs still has the same axes as word_seqs seq_to_path_shape = k2.ragged.get_layer(unique_word_seqs.shape(), 0) - - # path_to_seq_map is a 1-D torch.Tensor. - # path_to_seq_map[i] is the seq to which the i-th path - # belongs. path_to_seq_map = seq_to_path_shape.row_ids(1) - - # Remove the seq axis. - # Now unique_word_seqs has only two axes [path][word] unique_word_seqs = k2.ragged.remove_axis(unique_word_seqs, 0) - # word_fsas is an FsaVec with axes [path][state][arc] word_fsas = k2.linear_fsa(unique_word_seqs) - word_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsas) - am_scores = compute_am_scores(lats, word_fsas_with_epsilon_loops, path_to_seq_map) - import pdb; pdb.set_trace() - # lats has token IDs as labels and word IDs as aux_labels. # inv_lats has word IDs as labels and token IDs as aux_labels - # Do k2.invert to make it compatible to function compute_am_scores - # inv_lats = k2.invert(lats) inv_lats = k2.arc_sort(k2.invert(lats)) # no-op if inv_lats is already arc-sorted - # lats = k2.arc_sort(lats) fgram_lm_lats = _intersect_device(inv_lats, word_fsas_with_epsilon_loops, path_to_seq_map, sorted_match_a=True) fgram_lm_lats = k2.top_sort(k2.connect(fgram_lm_lats)) + # log_semiring=False is a little better than log_semiring=True. 
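On the log_semiring choice noted in the comment above: roughly speaking, log_semiring=True combines alternative paths by a log-sum, while log_semiring=False keeps only the best one (tropical semiring). A toy comparison with made-up path scores:

import torch

scores = torch.tensor([-3.0, -3.2, -5.0])   # hypothetical alternative path scores
print(torch.logsumexp(scores, dim=0))        # log-semiring style total
print(scores.max())                          # tropical (log_semiring=False) style total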
fgram_tot_scores = fgram_lm_lats.get_tot_scores(use_double_scores=True, log_semiring=False) fgram_lm_scores = fgram_tot_scores - am_scores + # Now token_seqs has only two axes [path][word] + token_seqs = k2.ragged.remove_axis(token_seqs, 0) + token_ids, _ = k2.ragged.index(token_seqs, new2old, axis=0) + token_ids = k2.ragged.to_list(token_ids) + return am_scores, fgram_lm_scores, token_ids, new2old + +def compute_am_flm_scrores_2(lats, word_seqs, token_seqs): + ''' + Compute am scores with token_seqs + wer is better than compute_am_flm_scores_1 + ''' + unique_token_seqs, _, new2old = k2.ragged.unique_sequences( + token_seqs, need_num_repeats=False, need_new2old_indexes=True) + + seq_to_path_shape = k2.ragged.get_layer(unique_token_seqs.shape(), 0) + path_to_seq_map = seq_to_path_shape.row_ids(1) + + unique_token_seqs = k2.ragged.remove_axis(unique_token_seqs, 0) + # token_fsas is an FsaVec with axes [path][state][arc] + token_fsas = k2.linear_fsa(unique_token_seqs) + token_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(token_fsas) + # lats has token IDs as labels and word IDs as aux_labels. + # inv_lats has word IDs as labels and token IDs as aux_labels + am_scores = compute_am_scores(k2.arc_sort(k2.invert(lats)), token_fsas_with_epsilon_loops, path_to_seq_map) + + fgram_lm_lats = _intersect_device(k2.arc_sort(lats), token_fsas_with_epsilon_loops, path_to_seq_map, sorted_match_a=True) + fgram_lm_lats = k2.top_sort(k2.connect(fgram_lm_lats)) + + # log_semiring=False is a little better than log_semiring=True. + fgram_tot_scores = fgram_lm_lats.get_tot_scores(use_double_scores=True, log_semiring=False) + fgram_lm_scores = fgram_tot_scores - am_scores - # now compute lm scores from transformer decoder # Now token_seqs has only two axes [path][word] token_seqs = k2.ragged.remove_axis(token_seqs, 0) token_ids, _ = k2.ragged.index(token_seqs, new2old, axis=0) token_ids = k2.ragged.to_list(token_ids) + return am_scores, fgram_lm_scores, token_ids, new2old + +def nbest_decoding(model, encoder_memory, memory_mask, lats: k2.Fsa, num_paths: int): + ''' + N-best rescore with transformer-decoder model. + The basic idea is to first extra n-best paths from the given lattice. + Then extract word_seqs and token_seqs for each path. + Compute the negative log-likehood for each token_seq as 'language model score', called decoder_scores. + Compute am score for each token_seq. + Total scores is a weight sum of am_score and decoder_scores. + The one with the max total score is used as the decoding output. 
+ ''' + # token_seqs, word_seqs, unique_token_seqs, unique_word_seqs = extract_nbest_list(lats, num_paths) + token_seqs, word_seqs = extract_nbest_list(lats, num_paths) + + # am_scores, fgram_lm_scores, token_ids, new2old = compute_am_flm_scrores_1(lats, word_seqs, token_seqs) + am_scores, fgram_lm_scores, token_ids, new2old = compute_am_flm_scrores_2(lats, word_seqs, token_seqs) + # now compute lm scores from transformer decoder num_seqs = len(token_ids) time_steps = encoder_memory.shape[0] feature_dim = encoder_memory.shape[2] @@ -130,11 +146,22 @@ def nbest_decoding(model, encoder_memory, memory_mask, lats: k2.Fsa, num_paths: nll = model.decoder_nll(encoder_memory, memory_mask, token_ids=token_ids) assert nll.shape[0] == num_seqs decoder_scores = - nll.sum(dim=1) - tot_scores = am_scores + fgram_lm_scores + decoder_scores - best_seq_idx = new2old[torch.argmax(tot_scores)] - best_word_seq = [k2.ragged.to_list(word_seqs)[0][best_seq_idx]] - return best_word_seq + flm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0, 4.0, 6.0, 8.0, 10.0] + + decoder_scale_list = [0.1, 0.3, 0.5, 0.7, 0.9, 1.0, 2.0, 4.0, 6.0, 8.0, 10.0] + decoder_scale_list += [0.01, 0.03, 0.05, 0.08, 0.09] + + ans = dict() + for flm_scale in flm_scale_list: + for decoder_scale in decoder_scale_list: + key = f'lm_scale_{flm_scale}_decoder_scale_{decoder_scale}' + tot_scores = am_scores + flm_scale * fgram_lm_scores + decoder_scale * decoder_scores + best_seq_idx = new2old[torch.argmax(tot_scores)] + best_word_seq = [k2.ragged.to_list(word_seqs)[0][best_seq_idx]] + ans[key] = best_word_seq + + return ans def decode_one_batch(batch: Dict[str, Any], model: AcousticModel, @@ -218,28 +245,11 @@ def decode_one_batch(batch: Dict[str, Any], lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0, output_beam_size, 30, 10000) - # TODO(Guo Liyong): figure out a way to combine lm_scale_list with transformer decoder n-best rescore - # lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] - # lm_scale_list += [0.45, 0.55, 0.65] - # lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] - # lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] - - lm_scale_list = [0.6] # lowest wer = 2.92 without transformer n-best rescore - - if use_whole_lattice: - best_paths_dict = rescore_with_whole_lattice(lattices, G, - lm_scale_list, - need_rescored_lats=True) - ans = dict() - for lm_scale_str, (best_paths, ngram_rescored_lattices) in best_paths_dict.items(): - assert best_paths.shape[0] == 1, 'Figuring out a way to do batch decoding' - if nbest_rescore_with_decoder: - best_word_seq = nbest_decoding(model, encoder_memory, memory_mask, ngram_rescored_lattices, num_paths) - hyps = best_word_seq - - else: - hyps = get_texts(best_paths, indices) - ans[lm_scale_str] = hyps + # fgram means four-gram + fgram_rescored_lattices = rescore_with_whole_lattice(lattices, G, + lm_scale_list=None, + need_rescored_lats=True) + ans = nbest_decoding(model, encoder_memory, memory_mask, fgram_rescored_lattices, num_paths) return ans @@ -252,6 +262,8 @@ def decode(dataloader: torch.utils.data.DataLoader, G: k2.Fsa, use_whole_lattice: bool, output_beam_size: float): + del HLG.lm_scores + HLG.lm_scores = HLG.scores.clone() tot_num_cuts = len(dataloader.dataset.cuts) num_cuts = 0 results = defaultdict(list) @@ -542,7 +554,8 @@ def main(): HLG.lm_scores = HLG.scores.clone() librispeech = LibriSpeechAsrDataModule(args) - test_sets = ['test-clean', 'test-other'] + # test_sets = ['test-clean', 'test-other'] + test_sets = ['test-clean'] for test_set, test_dl in zip(test_sets, 
librispeech.test_dataloaders()): logging.info(f'* DECODING: {test_set}') diff --git a/snowfall/decoding/lm_rescore.py b/snowfall/decoding/lm_rescore.py index 6d153c53..caa5d01d 100644 --- a/snowfall/decoding/lm_rescore.py +++ b/snowfall/decoding/lm_rescore.py @@ -299,12 +299,15 @@ def rescore_with_whole_lattice(lats: k2.Fsa, G_with_epsilon_loops: k2.Fsa, b_to_a_map, sorted_match_a=True) - rescoring_lats = k2.top_sort(k2.connect(rescoring_lats.to('cpu')).to(device)) + rescoring_lats = k2.top_sort(k2.connect(rescoring_lats)) # inv_lats has phone IDs as labels # and word IDs as aux_labels. inv_lats = k2.invert(rescoring_lats) + if need_rescored_lats: + return inv_lats + ans = dict() # # The following implements @@ -319,8 +322,5 @@ def rescore_with_whole_lattice(lats: k2.Fsa, G_with_epsilon_loops: k2.Fsa, best_paths = k2.shortest_path(inv_lats, use_double_scores=True) key = f'lm_scale_{lm_scale}' - if need_rescored_lats: - ans[key] = (best_paths, inv_lats) - else: - ans[key] = best_paths + ans[key] = best_paths return ans From cfa9f721743986fab1b77d2d865fad19bf815bfe Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Mon, 12 Jul 2021 17:35:48 +0800 Subject: [PATCH 6/9] log_semiring=False and remove repeat tokens --- .../simple_v1/bpe_ctc_att_conformer_decode.py | 165 +++++++++--------- snowfall/decoding/lm_rescore.py | 60 +++++++ 2 files changed, 141 insertions(+), 84 deletions(-) diff --git a/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py b/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py index 9e9d710c..4e614662 100755 --- a/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py +++ b/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py @@ -31,96 +31,79 @@ from snowfall.decoding.graph import compile_HLG from snowfall.decoding.lm_rescore import rescore_with_n_best_list from snowfall.decoding.lm_rescore import rescore_with_whole_lattice -from snowfall.decoding.lm_rescore import compute_am_scores, _intersect_device +from snowfall.decoding.lm_rescore import compute_am_scores_and_fm_scores from snowfall.models import AcousticModel from snowfall.models.conformer import Conformer from snowfall.text.numericalizer import Numericalizer from snowfall.training.ctc_graph import build_ctc_topo from snowfall.training.mmi_graph import get_phone_symbols - -def extract_nbest_list(lats: k2.Fsa, num_paths: int): +def remove_repeated_and_leq(tokens: List[int], blank_id: int = 0): + ''' + Genrate valid token sequence. + Result may be used as input of transformer decoder and neural language model. + Fristly, remove repeated token from a "token alignment" seqs; + Then remove blank symbols. + + This fuction may be replaced by tokenizing word_seqs with tokenizer + or composeing word_seqs_fsas with L_inv.fst + or composing token_seqs with ctc_topo. + Current method is chosed other than previous three methods because it won't need an extra object, i.e. tokenizer, L.fst or ctc_topo. + ''' + new_tokens = [] + previous = None + for token in tokens: + if token != previous: + new_tokens.append(token) + previous = token + new_tokens = [token for token in new_tokens if token > blank_id] + return new_tokens + +def nbest_am_flm_scrores(lats: k2.Fsa, num_paths: int): + ''' + Compute am scores with word_seqs + ''' # lats has token IDs as labels # and word IDs as aux_labels. # First, extract `num_paths` paths for each sequence. 
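A small usage note on `remove_repeated_and_leq` above: repeats are merged before blanks are dropped, so a genuinely doubled word-piece survives as long as a blank separates the two occurrences, exactly as in CTC decoding. A self-contained sketch (the helper is restated in an equivalent form so the snippet runs on its own):

    def remove_repeated_and_leq(tokens, blank_id=0):
        # Same behaviour as the helper above: merge consecutive duplicates,
        # then drop blanks/epsilons (ids <= blank_id).
        deduped = [t for i, t in enumerate(tokens) if i == 0 or t != tokens[i - 1]]
        return [t for t in deduped if t > blank_id]

    # Frame-level token alignment from a lattice path (0 is the blank symbol).
    # The doubled token 5 is kept because a blank sits between its two runs.
    alignment = [0, 0, 5, 5, 0, 5, 7, 7, 0]
    assert remove_repeated_and_leq(alignment) == [5, 5, 7]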
# paths is a k2.RaggedInt with axes [seq][path][arc_pos] paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True) - # Both token_seqs and word_seqs are k2.RaggedInt sharing the same shape as `paths` - # Note that they also contain 0s and -1s. + # word_seqs is a k2.RaggedInt sharing the same shape as `paths` + # but it contains word IDs. Note that it also contains 0s and -1s. # The last entry in each sublist is -1. - token_seqs = k2.index(lats.labels.contiguous(), paths) - word_seqs = k2.index(lats.aux_labels.contiguous(), paths) - token_seqs = k2.ragged.remove_values_leq(token_seqs, 0) + word_seqs = k2.index(lats.aux_labels.contiguous(), paths) word_seqs = k2.ragged.remove_values_leq(word_seqs, 0) - return token_seqs, word_seqs -def compute_am_flm_scrores_1(lats, word_seqs, token_seqs): - ''' - Compute am scores with word_seqs - wer is worse than compute_am_flm_scores_2 - ''' # lats has token IDs as labels and word IDs as aux_labels. unique_word_seqs, _, new2old = k2.ragged.unique_sequences( word_seqs, need_num_repeats=False, need_new2old_indexes=True) seq_to_path_shape = k2.ragged.get_layer(unique_word_seqs.shape(), 0) path_to_seq_map = seq_to_path_shape.row_ids(1) + + # used to split final computed tot_scores + seq_to_path_splits = seq_to_path_shape.row_splits(1) + unique_word_seqs = k2.ragged.remove_axis(unique_word_seqs, 0) # word_fsas is an FsaVec with axes [path][state][arc] word_fsas = k2.linear_fsa(unique_word_seqs) word_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsas) - am_scores = compute_am_scores(lats, word_fsas_with_epsilon_loops, path_to_seq_map) - # lats has token IDs as labels and word IDs as aux_labels. - # inv_lats has word IDs as labels and token IDs as aux_labels - inv_lats = k2.arc_sort(k2.invert(lats)) # no-op if inv_lats is already arc-sorted - - fgram_lm_lats = _intersect_device(inv_lats, word_fsas_with_epsilon_loops, path_to_seq_map, sorted_match_a=True) - fgram_lm_lats = k2.top_sort(k2.connect(fgram_lm_lats)) - - # log_semiring=False is a little better than log_semiring=True. - fgram_tot_scores = fgram_lm_lats.get_tot_scores(use_double_scores=True, log_semiring=False) - fgram_lm_scores = fgram_tot_scores - am_scores + am_scores, lm_scores = compute_am_scores_and_fm_scores(lats, word_fsas_with_epsilon_loops, path_to_seq_map) - # Now token_seqs has only two axes [path][word] + # token_seqs is a k2.RaggedInt sharing the same shape as `paths` + # but it contains token IDs. + # Note that it also contains 0s and -1s. + token_seqs = k2.index(lats.labels.contiguous(), paths) token_seqs = k2.ragged.remove_axis(token_seqs, 0) token_ids, _ = k2.ragged.index(token_seqs, new2old, axis=0) token_ids = k2.ragged.to_list(token_ids) - return am_scores, fgram_lm_scores, token_ids, new2old - -def compute_am_flm_scrores_2(lats, word_seqs, token_seqs): - ''' - Compute am scores with token_seqs - wer is better than compute_am_flm_scores_1 - ''' - unique_token_seqs, _, new2old = k2.ragged.unique_sequences( - token_seqs, need_num_repeats=False, need_new2old_indexes=True) - - seq_to_path_shape = k2.ragged.get_layer(unique_token_seqs.shape(), 0) - path_to_seq_map = seq_to_path_shape.row_ids(1) - - unique_token_seqs = k2.ragged.remove_axis(unique_token_seqs, 0) - # token_fsas is an FsaVec with axes [path][state][arc] - token_fsas = k2.linear_fsa(unique_token_seqs) - token_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(token_fsas) - # lats has token IDs as labels and word IDs as aux_labels. 
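`path_to_seq_map` and `seq_to_path_splits` describe the same [seq][path] layout in two forms: the former (row ids) says which utterance each surviving path belongs to, the latter (row splits) marks where each utterance's block of paths starts and ends in the flat per-path tensors. A pure-PyTorch illustration with assumed path counts of 12 and 9, matching the example printed further down in this file:

    import torch

    # Two utterances; after deduplication they keep 12 and 9 unique paths.
    seq_to_path_splits = torch.tensor([0, 12, 21])            # row_splits(1)
    path_to_seq_map = torch.repeat_interleave(
        torch.arange(2), torch.diff(seq_to_path_splits))       # row_ids(1)

    assert path_to_seq_map.tolist() == [0] * 12 + [1] * 9

    # Any flat per-path tensor (am_scores, lm_scores, decoder scores, ...)
    # of length 21 can be sliced back per utterance with the splits:
    am_scores = torch.randn(21)
    splits = seq_to_path_splits.tolist()
    per_utt = [am_scores[s:e] for s, e in zip(splits[:-1], splits[1:])]
    assert [len(x) for x in per_utt] == [12, 9]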
- # inv_lats has word IDs as labels and token IDs as aux_labels - am_scores = compute_am_scores(k2.arc_sort(k2.invert(lats)), token_fsas_with_epsilon_loops, path_to_seq_map) - - fgram_lm_lats = _intersect_device(k2.arc_sort(lats), token_fsas_with_epsilon_loops, path_to_seq_map, sorted_match_a=True) - fgram_lm_lats = k2.top_sort(k2.connect(fgram_lm_lats)) + # Now remove repeated tokens and 0s and -1s. + token_ids = [remove_repeated_and_leq(tokens) for tokens in token_ids] - # log_semiring=False is a little better than log_semiring=True. - fgram_tot_scores = fgram_lm_lats.get_tot_scores(use_double_scores=True, log_semiring=False) - fgram_lm_scores = fgram_tot_scores - am_scores + return am_scores, lm_scores, token_ids, new2old, path_to_seq_map, seq_to_path_splits, word_seqs - # Now token_seqs has only two axes [path][word] - token_seqs = k2.ragged.remove_axis(token_seqs, 0) - token_ids, _ = k2.ragged.index(token_seqs, new2old, axis=0) - token_ids = k2.ragged.to_list(token_ids) - return am_scores, fgram_lm_scores, token_ids, new2old - -def nbest_decoding(model, encoder_memory, memory_mask, lats: k2.Fsa, num_paths: int): +def nbest_rescoring(model, encoder_memory, memory_mask, lats: k2.Fsa, num_paths: int): ''' N-best rescore with transformer-decoder model. The basic idea is to first extra n-best paths from the given lattice. @@ -130,36 +113,51 @@ def nbest_decoding(model, encoder_memory, memory_mask, lats: k2.Fsa, num_paths: Total scores is a weight sum of am_score and decoder_scores. The one with the max total score is used as the decoding output. ''' - # token_seqs, word_seqs, unique_token_seqs, unique_word_seqs = extract_nbest_list(lats, num_paths) - token_seqs, word_seqs = extract_nbest_list(lats, num_paths) + am_scores, fgram_lm_scores, token_ids, new2old, path_to_seq_map, seq_to_path_splits, word_seqs = nbest_am_flm_scrores(lats, num_paths=num_paths) - # am_scores, fgram_lm_scores, token_ids, new2old = compute_am_flm_scrores_1(lats, word_seqs, token_seqs) - am_scores, fgram_lm_scores, token_ids, new2old = compute_am_flm_scrores_2(lats, word_seqs, token_seqs) - # now compute lm scores from transformer decoder + # Start to compute lm scores from transformer decoder. 
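The decoder score of a path is obtained by replicating that utterance's encoder output once per surviving path and summing the token-level negative log-likelihoods. A minimal sketch of this step, assuming the [T, N, C] memory and [N, T] mask layouts and the `model.decoder_nll(...)` call used in this file; it is a consolidation of the code below, not a separate implementation:

    def decoder_scores_for_paths(model, encoder_memory, memory_mask,
                                 token_ids, path_to_seq_map):
        # encoder_memory: [T, N, C]; memory_mask: [N, T];
        # token_ids: list of token-id lists, one per surviving path;
        # path_to_seq_map: 1-D index tensor, path_to_seq_map[i] is the
        #                  utterance index of the i-th path (same device
        #                  as the encoder memory).
        expanded_memory = encoder_memory.index_select(1, path_to_seq_map)
        expanded_mask = memory_mask.index_select(0, path_to_seq_map)
        # nll[i, j]: negative log-likelihood of the j-th token of the i-th path.
        nll = model.decoder_nll(expanded_memory, expanded_mask, token_ids=token_ids)
        assert nll.shape[0] == len(token_ids)
        # Negate so that, like am_scores and the 4-gram scores, larger is better.
        return -nll.sum(dim=1)

The returned per-path scores are what the scale grid combines with `am_scores` and the 4-gram LM scores.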
+ + # an example of path_to_seq_map is: + # tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1], + # device='cuda:0', dtype=torch.int32) + path_to_seq_map = torch.tensor(path_to_seq_map).to(lats.device) + seq_to_path_splits = seq_to_path_splits.to('cpu').long() num_seqs = len(token_ids) - time_steps = encoder_memory.shape[0] - feature_dim = encoder_memory.shape[2] - encoder_memory = encoder_memory.expand(time_steps, num_seqs, feature_dim) - memory_mask = memory_mask.expand(num_seqs, time_steps) + # encoder_memory shape: [T, N, C] --> [T, (nbest1 + nbest2 + **), C] + encoder_memory = encoder_memory.index_select(1, path_to_seq_map) + # memory_mask shape: [N, T] --> [(nbest1+nbest2), T] + memory_mask = memory_mask.index_select(0, path_to_seq_map) # nll: negative log-likelihood nll = model.decoder_nll(encoder_memory, memory_mask, token_ids=token_ids) assert nll.shape[0] == num_seqs decoder_scores = - nll.sum(dim=1) - flm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0, 4.0, 6.0, 8.0, 10.0] - - decoder_scale_list = [0.1, 0.3, 0.5, 0.7, 0.9, 1.0, 2.0, 4.0, 6.0, 8.0, 10.0] - decoder_scale_list += [0.01, 0.03, 0.05, 0.08, 0.09] + flm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0, 1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] + decoder_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0, 1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] ans = dict() + word_seqs = k2.ragged.to_list(k2.ragged.remove_axis(word_seqs,0)) for flm_scale in flm_scale_list: for decoder_scale in decoder_scale_list: key = f'lm_scale_{flm_scale}_decoder_scale_{decoder_scale}' - tot_scores = am_scores + flm_scale * fgram_lm_scores + decoder_scale * decoder_scores - best_seq_idx = new2old[torch.argmax(tot_scores)] - best_word_seq = [k2.ragged.to_list(word_seqs)[0][best_seq_idx]] - ans[key] = best_word_seq + batch_tot_scores = am_scores + flm_scale * fgram_lm_scores + decoder_scale * decoder_scores + batch_tot_scores = torch.tensor_split(batch_tot_scores, seq_to_path_splits[1:]) + ans[key] = [] + processed_seqs = 0 + for tot_scores in batch_tot_scores: + if tot_scores.nelement() == 0: + # the last element by torch.tensor_split may be empty + # e.g. + # torch.tensor_split(torch.tensor([1,2,3,4]), torch.tensor([2,4])) + # (tensor([1, 2]), tensor([3, 4]), tensor([], dtype=torch.int64)) + + break + best_seq_idx = new2old[processed_seqs + torch.argmax(tot_scores)] + best_word_seq = word_seqs[best_seq_idx] + processed_seqs += tot_scores.nelement() + ans[key].append(best_word_seq) + assert len(ans[key]) == seq_to_path_splits.nelement() - 1 return ans @@ -218,7 +216,6 @@ def decode_one_batch(batch: Dict[str, Any], assert feature.ndim == 3 feature = feature.to(device) batch_size = feature.shape[0] - assert batch_size == 1, 'Currently only surrort batch_size=1' # at entry, feature is [N, T, C] feature = feature.permute(0, 2, 1) # now feature is [N, C, T] @@ -249,7 +246,7 @@ def decode_one_batch(batch: Dict[str, Any], fgram_rescored_lattices = rescore_with_whole_lattice(lattices, G, lm_scale_list=None, need_rescored_lats=True) - ans = nbest_decoding(model, encoder_memory, memory_mask, fgram_rescored_lattices, num_paths) + ans = nbest_rescoring(model, encoder_memory, memory_mask, fgram_rescored_lattices, num_paths) return ans @@ -274,6 +271,8 @@ def decode(dataloader: torch.utils.data.DataLoader, # - value: It is a list of tuples (ref_words, hyp_words) for batch_idx, batch in enumerate(dataloader): texts = batch['supervisions']['text'] + # TODO(Liyong Guo): Something wrong with batch_size > 1, fix this. 
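The per-utterance argmax above leans on a detail of `torch.tensor_split`: when the last boundary equals the tensor length, an empty trailing chunk is produced, which is what the `nelement() == 0` break handles. A small self-contained example with made-up scores (3 paths for the first utterance, 2 for the second):

    import torch

    tot_scores = torch.tensor([0.5, 1.2, 0.7, 2.0, 1.9])   # one entry per path
    seq_to_path_splits = torch.tensor([0, 3, 5])

    chunks = torch.tensor_split(tot_scores, seq_to_path_splits[1:])
    # chunks == (tensor([0.5, 1.2, 0.7]), tensor([2.0, 1.9]), tensor([]))
    # The trailing empty chunk appears because the last split index equals
    # len(tot_scores).

    best_flat_indexes = [
        int(torch.argmax(c)) + offset
        for c, offset in zip(chunks, seq_to_path_splits[:-1].tolist())
        if c.nelement() > 0
    ]
    assert best_flat_indexes == [1, 3]   # indexes into the flat per-path arrays

In the code above, each flat best index is additionally mapped through `new2old` to recover the index of the original, pre-deduplication path.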
+ assert len(texts) == 1 hyps_dict = decode_one_batch(batch=batch, model=model, @@ -286,13 +285,11 @@ def decode(dataloader: torch.utils.data.DataLoader, for lm_scale, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) - - for i in range(len(texts)): - hyp_words = [symbols.get(x) for x in hyps[i]] - ref_words = texts[i].split(' ') + for hyp, text in zip(hyps, texts): + hyp_words = [symbols.get(x) for x in hyp] + ref_words = text.split(' ') this_batch.append((ref_words, hyp_words)) - - results[lm_scale].extend(this_batch) + results[lm_scale].extend(this_batch) if batch_idx % 10 == 0: logging.info( diff --git a/snowfall/decoding/lm_rescore.py b/snowfall/decoding/lm_rescore.py index caa5d01d..0cc5ea6b 100644 --- a/snowfall/decoding/lm_rescore.py +++ b/snowfall/decoding/lm_rescore.py @@ -3,6 +3,7 @@ from typing import Dict from typing import List from typing import Optional +from typing import Tuple import math @@ -49,6 +50,65 @@ def _intersect_device(a_fsas: k2.Fsa, b_fsas: k2.Fsa, b_to_a_map: torch.Tensor, return k2.cat(ans) +def compute_am_scores_and_fm_scores(lats: k2.Fsa, word_fsas_with_epsilon_loops: k2.Fsa, + path_to_seq_map: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + '''Compute AM and LM scores of n-best lists (represented as word_fsas). + + Args: + lats: + An FsaVec, which is the output of `k2.intersect_dense_pruned`. + It must have the attribute `lm_scores`. + word_fsas_with_epsilon_loops: + An FsaVec representing a n-best list. Note that it has been processed + by `k2.add_epsilon_self_loops`. + path_to_seq_map: + A 1-D torch.Tensor with dtype torch.int32. path_to_seq_map[i] indicates + which sequence the i-th Fsa in word_fsas_with_epsilon_loops belongs to. + path_to_seq_map.numel() == word_fsas_with_epsilon_loops.arcs.dim0(). + Returns: + Return a tuple of (1-D torch.Tensor, 1-D torch.Tensor) containing the AM and FM scores of each path. + `am_scores.numel() == word_fsas_with_epsilon_loops.shape[0]` + `lm_scores.numel() == word_fsas_with_epsilon_loops.shape[0]` + ''' + device = lats.device + assert len(lats.shape) == 3 + assert hasattr(lats, 'lm_scores') + + # k2.compose() currently does not support b_to_a_map. To void + # replicating `lats`, we use k2.intersect_device here. + # + # lats has phone IDs as `labels` and word IDs as aux_labels, so we + # need to invert it here. 
+ inverted_lats = k2.invert(lats) + + # Now the `labels` of inverted_lats are word IDs (a 1-D torch.Tensor) + # and its `aux_labels` are phone IDs ( a k2.RaggedInt with 2 axes) + + # Remove its `aux_labels` since it is not needed in the + # following computation + del inverted_lats.aux_labels + inverted_lats = k2.arc_sort(inverted_lats) + + am_path_lats = _intersect_device(inverted_lats, + word_fsas_with_epsilon_loops, + b_to_a_map=path_to_seq_map, + sorted_match_a=True) + + am_path_lats = k2.top_sort(k2.connect(am_path_lats)) + + # The `scores` of every arc consists of `am_scores` and `lm_scores` + am_path_lats.scores = am_path_lats.scores - am_path_lats.lm_scores + + # am_scores = am_path_lats.get_tot_scores(True, True) # wer: 2.77 + am_scores = am_path_lats.get_tot_scores(use_double_scores=True, log_semiring=False) # wer 2.73 + + # Start to compute lm_scores + am_path_lats.scores = am_path_lats.lm_scores + + # fm_scores = am_path_lats.get_tot_scores(True, True) # wer: 2.77 + lm_scores = am_path_lats.get_tot_scores(use_double_scores=True, log_semiring=False) # wer 2.73 + + return am_scores, lm_scores def compute_am_scores(lats: k2.Fsa, word_fsas_with_epsilon_loops: k2.Fsa, path_to_seq_map: torch.Tensor) -> torch.Tensor: From 918152436f131951fa5fab5c898e4497ba9d41e0 Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Tue, 13 Jul 2021 17:36:27 +0800 Subject: [PATCH 7/9] support batch decoding --- .../asr/simple_v1/bpe_ctc_att_conformer_decode.py | 9 ++++++--- egs/librispeech/asr/simple_v1/bpe_run.sh | 3 +-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py b/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py index 4e614662..d8cc44f8 100755 --- a/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py +++ b/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py @@ -236,6 +236,10 @@ def decode_one_batch(batch: Dict[str, Any], supervision_segments = torch.clamp(supervision_segments, min=0) indices = torch.argsort(supervision_segments[:, 2], descending=True) + # cuts has been sorted in lhotse dataset + # https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L109 + assert torch.all(torch.argsort(indices) == indices) + supervision_segments = supervision_segments[indices] dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments) @@ -271,8 +275,6 @@ def decode(dataloader: torch.utils.data.DataLoader, # - value: It is a list of tuples (ref_words, hyp_words) for batch_idx, batch in enumerate(dataloader): texts = batch['supervisions']['text'] - # TODO(Liyong Guo): Something wrong with batch_size > 1, fix this. 
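One caveat about the sortedness assertion added above: `torch.argsort(indices) == indices` also holds for any permutation that is its own inverse, not only for the identity. If the intent is that the cuts arrive pre-sorted, so that `indices` is the identity permutation, a direct comparison with `torch.arange` is the stricter check; the snippet below is an illustration, not part of the patch:

    import torch

    # Stricter form of the "cuts already sorted" assumption:
    # assert torch.all(indices == torch.arange(indices.numel(),
    #                                          device=indices.device))

    # Why the original check is weaker: a swap is its own inverse.
    indices = torch.tensor([1, 0, 2])                       # not the identity
    assert torch.all(torch.argsort(indices) == indices)     # still passes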
- assert len(texts) == 1 hyps_dict = decode_one_batch(batch=batch, model=model, @@ -289,7 +291,8 @@ def decode(dataloader: torch.utils.data.DataLoader, hyp_words = [symbols.get(x) for x in hyp] ref_words = text.split(' ') this_batch.append((ref_words, hyp_words)) - results[lm_scale].extend(this_batch) + + results[lm_scale].extend(this_batch) if batch_idx % 10 == 0: logging.info( diff --git a/egs/librispeech/asr/simple_v1/bpe_run.sh b/egs/librispeech/asr/simple_v1/bpe_run.sh index b1c850ed..fd3bcfcc 100644 --- a/egs/librispeech/asr/simple_v1/bpe_run.sh +++ b/egs/librispeech/asr/simple_v1/bpe_run.sh @@ -103,9 +103,8 @@ fi if [ $stage -le 3 ]; then export CUDA_VISIBLE_DEVICES=2 - # Set max-duration=1 because rescore with decoder only support batch_size=1 python bpe_ctc_att_conformer_decode.py \ - --max-duration=1 \ + --max-duration=20 \ --generate-release-model=False \ --decode_with_released_model=True \ --num-paths-for-decoder-rescore=500 From fd93a50e18a2be17de6216951f8dc54482042ca8 Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Tue, 13 Jul 2021 20:00:51 +0800 Subject: [PATCH 8/9] decode test-other --- egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py b/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py index d8cc44f8..88c44824 100755 --- a/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py +++ b/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py @@ -554,8 +554,7 @@ def main(): HLG.lm_scores = HLG.scores.clone() librispeech = LibriSpeechAsrDataModule(args) - # test_sets = ['test-clean', 'test-other'] - test_sets = ['test-clean'] + test_sets = ['test-clean', 'test-other'] for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()): logging.info(f'* DECODING: {test_set}') From d51d32a64eba513534304dccd157ae7dce2aefb6 Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Wed, 21 Jul 2021 18:41:39 +0800 Subject: [PATCH 9/9] use batch_norm as MVN --- .../asr/simple_v1/bpe_ctc_att_conformer_decode.py | 9 ++++++--- egs/librispeech/asr/simple_v1/bpe_run.sh | 7 +++++-- snowfall/models/conformer.py | 4 ++-- snowfall/models/transformer.py | 8 +++++++- 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py b/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py index 88c44824..7845444b 100755 --- a/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py +++ b/egs/librispeech/asr/simple_v1/bpe_ctc_att_conformer_decode.py @@ -450,14 +450,17 @@ def main(): num_decoder_layers=num_decoder_layers, vgg_frontend=args.vgg_frontend, is_espnet_structure=args.is_espnet_structure, - mmi_loss=False) + mmi_loss=False, + use_feat_batchnorm=True) if args.espnet_identical_model: - assert sum([p.numel() for p in model.parameters()]) == 116146960 + # + 160 for feat_batch_norm, used as feature mean and variance normalization + assert sum([p.numel() for p in model.parameters()]) == 116146960 + 160 else: raise NotImplementedError("Model of type " + str(model_type) + " is not verified") - exp_dir = Path(f'exp-bpe-lrfactor{args.lr_factor}-{model_type}-{attention_dim}-{nhead}-noam/') + exp_dir = Path(f'exp-duration-200-feat_batchnorm-bpe-lrfactor{args.lr_factor}-{model_type}-{attention_dim}-{nhead}-noam/') + if args.decode_with_released_model is True: released_model_path = exp_dir / f'model-epoch-{epoch}-avg-{avg}.pt' 
model.load_state_dict(torch.load(released_model_path)) diff --git a/egs/librispeech/asr/simple_v1/bpe_run.sh b/egs/librispeech/asr/simple_v1/bpe_run.sh index fd3bcfcc..7c4cc2d4 100644 --- a/egs/librispeech/asr/simple_v1/bpe_run.sh +++ b/egs/librispeech/asr/simple_v1/bpe_run.sh @@ -17,7 +17,7 @@ if [ $download_model -eq 1 ]; then exit 0 fi git clone https://huggingface.co/GuoLiyong/snowfall_bpe_model - for sub_dir in data exp-bpe-lrfactor10.0-conformer-512-8-noam; do + for sub_dir in data exp-duration-200-feat_batchnorm-bpe-lrfactor5.0-conformer-512-8-noam; do ln -sf snowfall_bpe_model/$sub_dir ./ done fi @@ -104,7 +104,10 @@ fi if [ $stage -le 3 ]; then export CUDA_VISIBLE_DEVICES=2 python bpe_ctc_att_conformer_decode.py \ - --max-duration=20 \ + --epoch 51 \ + --avg 20 \ + --lr-factor 5.0 \ + --max-duration=10 \ --generate-release-model=False \ --decode_with_released_model=True \ --num-paths-for-decoder-rescore=500 diff --git a/snowfall/models/conformer.py b/snowfall/models/conformer.py index 690ae753..0a72450d 100644 --- a/snowfall/models/conformer.py +++ b/snowfall/models/conformer.py @@ -35,12 +35,12 @@ def __init__(self, num_features: int, num_classes: int, subsampling_factor: int num_encoder_layers: int = 12, num_decoder_layers: int = 6, dropout: float = 0.1, cnn_module_kernel: int = 31, normalize_before: bool = True, vgg_frontend: bool = False, - is_espnet_structure: bool = False, mmi_loss: bool = True) -> None: + is_espnet_structure: bool = False, mmi_loss: bool = True, use_feat_batchnorm: bool = False) -> None: super(Conformer, self).__init__(num_features=num_features, num_classes=num_classes, subsampling_factor=subsampling_factor, d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers, dropout=dropout, normalize_before=normalize_before, vgg_frontend=vgg_frontend, - mmi_loss=mmi_loss) + mmi_loss=mmi_loss, use_feat_batchnorm=use_feat_batchnorm) self.encoder_pos = RelPositionalEncoding(d_model, dropout) diff --git a/snowfall/models/transformer.py b/snowfall/models/transformer.py index a15dd70d..63807581 100644 --- a/snowfall/models/transformer.py +++ b/snowfall/models/transformer.py @@ -37,8 +37,12 @@ def __init__(self, num_features: int, num_classes: int, subsampling_factor: int d_model: int = 256, nhead: int = 4, dim_feedforward: int = 2048, num_encoder_layers: int = 12, num_decoder_layers: int = 6, dropout: float = 0.1, normalize_before: bool = True, - vgg_frontend: bool = False, mmi_loss: bool = True) -> None: + vgg_frontend: bool = False, mmi_loss: bool = True, use_feat_batchnorm:bool = False) -> None: super().__init__() + self.use_feat_batchnorm = use_feat_batchnorm + if use_feat_batchnorm: + self.feat_batchnorm = nn.BatchNorm1d(num_features) + self.num_features = num_features self.num_classes = num_classes self.subsampling_factor = subsampling_factor @@ -99,6 +103,8 @@ def forward(self, x: Tensor, supervision: Optional[Supervisions] = None) -> Tupl Optional[Tensor]: Mask tensor of dimension (batch_size, input_length) or None. """ + if self.use_feat_batchnorm: + x = self.feat_batchnorm(x) encoder_memory, memory_mask = self.encode(x, supervision) x = self.encoder_output(encoder_memory) return x, encoder_memory, memory_mask
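The last patch swaps explicit feature mean/variance normalization for a learned `nn.BatchNorm1d` over the filterbank features, which accounts for the `+ 160` parameters (80 weights plus 80 biases, so an 80-dimensional feature is implied). A minimal standalone sketch of the idea, assuming the [N, C, T] feature layout used before the model call in the decode script:

    import torch
    import torch.nn as nn

    num_features = 80                        # implied by the "+ 160" count
    feat_batchnorm = nn.BatchNorm1d(num_features)

    x = torch.randn(4, num_features, 300)    # [N, C, T] fbank features
    y = feat_batchnorm(x)                    # normalized per feature bin

    # In training mode each feature bin is normalized over the batch and time
    # axes, so (before the learned affine drifts from its 1/0 init) the output
    # is roughly zero-mean, unit-variance per bin:
    print(y.mean(dim=(0, 2))[:4])
    print(y.std(dim=(0, 2))[:4])

    # Parameter count matches the "+ 160" noted in the decode script:
    assert sum(p.numel() for p in feat_batchnorm.parameters()) == 2 * num_features

At inference time the running statistics act as a global mean/variance normalization, which is why the commit message describes it as "batch_norm as MVN".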