From b1ee46f731f3522c4028879bc726a72af24b16da Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Mon, 21 Jun 2021 16:14:56 +0800
Subject: [PATCH 1/3] Support choosing a unigram or bigram LM for P in LF-MMI
 training.

---
 .../simple_v1/mmi_att_transformer_decode.py   | 22 +++++++++++++--
 .../simple_v1/mmi_att_transformer_train.py    | 23 +++++++++++++--
 snowfall/training/mmi_graph.py                | 28 ++++++++++++++++++-
 3 files changed, 68 insertions(+), 5 deletions(-)

diff --git a/egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py b/egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py
index 76c7cc08..000acff6 100755
--- a/egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py
+++ b/egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py
@@ -71,6 +71,7 @@ from snowfall.models.contextnet import ContextNet
 from snowfall.training.ctc_graph import build_ctc_topo
 from snowfall.training.mmi_graph import create_bigram_phone_lm
+from snowfall.training.mmi_graph import create_unigram_phone_lm
 from snowfall.training.mmi_graph import get_phone_symbols
 
 
 def nbest_decoding(lats: k2.Fsa, num_paths: int):
@@ -401,6 +402,15 @@ def get_parser():
         type=str2bool,
         default=True,
         help='When enabled, it uses vgg style network for subsampling')
+
+    parser.add_argument(
+        '--use-unigram-lm',
+        type=str2bool,
+        default=False,
+        help='True to use unigram LM for P. False to use bigram LM for P. '
+        'This is used only for checkpoint-loading'.
+    )
+
     return parser
 
 
@@ -423,7 +433,10 @@ def main():
 
     output_beam_size = args.output_beam_size
 
-    exp_dir = Path('exp-' + model_type + '-mmi-att-sa-vgg-normlayer')
+    if args.use_unigram_lm:
+        exp_dir = Path('exp-' + model_type + '-mmi-att-sa-vgg-normlayer-unigram')
+    else:
+        exp_dir = Path('exp-' + model_type + '-mmi-att-sa-vgg-normlayer')
     setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')
     logging.info(f'output_beam_size: {output_beam_size}')
 
@@ -434,7 +447,12 @@
     phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
     phone_ids = get_phone_symbols(phone_symbol_table)
-    P = create_bigram_phone_lm(phone_ids)
+    if args.use_unigram_lm:
+        logging.info('Use unigram LM for P')
+        P = create_unigram_phone_lm(phone_ids)
+    else:
+        logging.info('Use bigram LM for P')
+        P = create_bigram_phone_lm(phone_ids)
 
     phone_ids_with_blank = [0] + phone_ids
     ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))
diff --git a/egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py b/egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py
index f9526481..29ebdcb3 100755
--- a/egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py
+++ b/egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py
@@ -44,6 +44,7 @@ from snowfall.training.diagnostics import measure_gradient_norms, optim_step_and_measure_param_change
 from snowfall.training.mmi_graph import MmiTrainingGraphCompiler
 from snowfall.training.mmi_graph import create_bigram_phone_lm
+from snowfall.training.mmi_graph import create_unigram_phone_lm
 
 
 def get_objf(batch: Dict,
@@ -461,6 +462,14 @@ def get_parser():
         'so that they can be simply loaded with torch.jit.load(). '
         '-1 disables this option.'
     )
+
+    parser.add_argument(
+        '--use-unigram-lm',
+        type=str2bool,
+        default=False,
+        help='True to use unigram LM for P. False to use bigram LM for P.'
+    )
+
     return parser
 
 
@@ -487,7 +496,10 @@ def run(rank, world_size, args):
     fix_random_seed(42)
     setup_dist(rank, world_size, args.master_port)
 
-    exp_dir = Path('exp-' + model_type + '-mmi-att-sa-vgg-normlayer')
+    if args.use_unigram_lm:
+        exp_dir = Path('exp-' + model_type + '-mmi-att-sa-vgg-normlayer-unigram')
+    else:
+        exp_dir = Path('exp-' + model_type + '-mmi-att-sa-vgg-normlayer')
     setup_logger(f'{exp_dir}/log/log-train-{rank}')
     if args.tensorboard and rank == 0:
         tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')
@@ -507,7 +519,14 @@ def run(rank, world_size, args):
         device=device,
     )
     phone_ids = lexicon.phone_symbols()
-    P = create_bigram_phone_lm(phone_ids)
+
+    if args.use_unigram_lm:
+        logging.info('Use unigram LM for P')
+        P = create_bigram_phone_lm(phone_ids)
+    else:
+        logging.info('Use bigram LM for P')
+        P = create_bigram_phone_lm(phone_ids)
+
     P.scores = torch.zeros_like(P.scores)
     P = P.to(device)
 
diff --git a/snowfall/training/mmi_graph.py b/snowfall/training/mmi_graph.py
index 758830f7..c970bbb3 100644
--- a/snowfall/training/mmi_graph.py
+++ b/snowfall/training/mmi_graph.py
@@ -38,7 +38,33 @@ def create_bigram_phone_lm(phones: List[int]) -> k2.Fsa:
             rules += f'{i} {j} {phones[j-1]} 0.0\n'
         rules += f'{i} {final_state} -1 0.0\n'
     rules += f'{final_state}'
-    return k2.Fsa.from_str(rules)
+    ans = k2.Fsa.from_str(rules)
+    return k2.arc_sort(ans)
+
+def create_unigram_phone_lm(phones: List[int]) -> k2.Fsa:
+    '''Create a unigram phone LM.
+    The resulting FSA (P) has two states: a start state and a
+    final state. For each phone, there is a corresponding self-loop
+    at the start state.
+
+    Caution:
+      blank is not a phone.
+
+    Args:
+      phones: A list of phone IDs.
+
+    Returns:
+      An FSA representing the unigram phone LM.
+    '''
+    assert 0 not in phones
+
+    rules = '0 1 -1 0.0\n'
+    for i in phones:
+        rules += f'0 0 {i} 0.0\n'
+    rules += '1\n'
+
+    ans = k2.Fsa.from_str(rules)
+    return k2.arc_sort(ans)
 
 
 class MmiTrainingGraphCompiler(object):

From d6c486bfed0e86788aeceba53f458b0ce9baaa64 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Mon, 21 Jun 2021 16:36:40 +0800
Subject: [PATCH 2/3] Fix a typo.

---
 egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py b/egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py
index 29ebdcb3..3601843d 100755
--- a/egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py
+++ b/egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py
@@ -522,7 +522,7 @@ def run(rank, world_size, args):
 
     if args.use_unigram_lm:
         logging.info('Use unigram LM for P')
-        P = create_bigram_phone_lm(phone_ids)
+        P = create_unigram_phone_lm(phone_ids)
     else:
         logging.info('Use bigram LM for P')
         P = create_bigram_phone_lm(phone_ids)

From 45898f70e3982a47a6fb831c6bb46d5acec9aaf2 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Tue, 22 Jun 2021 09:47:49 +0800
Subject: [PATCH 3/3] Fix a typo.

---
 egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py b/egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py
index 000acff6..86b05c74 100755
--- a/egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py
+++ b/egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py
@@ -408,7 +408,7 @@ def get_parser():
         type=str2bool,
         default=False,
         help='True to use unigram LM for P. False to use bigram LM for P. '
-        'This is used only for checkpoint-loading'.
+        'This is used only for checkpoint-loading.'
     )
 
     return parser
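A quick way to sanity-check the unigram P introduced by this series (a
minimal sketch, not part of the patch; it assumes snowfall with these
patches applied and k2 are installed, and the phone IDs below are made
up for illustration):

    from snowfall.training.mmi_graph import create_unigram_phone_lm

    phone_ids = [1, 2, 3]  # hypothetical phone IDs; 0 (blank) must not appear
    P = create_unigram_phone_lm(phone_ids)

    # One 0.0-score self-loop per phone at state 0, plus a single arc
    # labeled -1 to the final state, so P initially scores every phone
    # sequence equally; training later overwrites P.scores.
    print(P.num_arcs)  # len(phone_ids) + 1, i.e. 4 here
    print(P)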