From 0219ad40fcc33a0082a229376c006e696c301907 Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Fri, 27 Aug 2021 13:56:49 +0800 Subject: [PATCH 01/10] Update mmi_bigram_train.py --- egs/aishell/asr/simple_v1/mmi_bigram_train.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/egs/aishell/asr/simple_v1/mmi_bigram_train.py b/egs/aishell/asr/simple_v1/mmi_bigram_train.py index 2ff9342d..c24b98a5 100644 --- a/egs/aishell/asr/simple_v1/mmi_bigram_train.py +++ b/egs/aishell/asr/simple_v1/mmi_bigram_train.py @@ -131,10 +131,6 @@ def maybe_log_gradients(tag: str): optimizer.zero_grad() (-mmi_loss).backward() - for name, param in model.named_parameters(): - if param.grad is None: - print(name) - maybe_log_gradients('train/grad_norms') #clip_grad_value_(model.parameters(), 5.0) clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2.0) From 2d92b5339f18927f6540f8bc44d902aab7192558 Mon Sep 17 00:00:00 2001 From: Luo Mingshuang <739314837@qq.com> Date: Fri, 3 Sep 2021 11:18:10 +0800 Subject: [PATCH 02/10] Add timit recipe --- egs/timit/asr/simple_v1/RESULTS.md | 20 + egs/timit/asr/simple_v1/ctc_decode.py | 198 +++++++++ egs/timit/asr/simple_v1/ctc_train.py | 418 ++++++++++++++++++ .../asr/simple_v1/local/add_lex_disambig.pl | 196 ++++++++ .../local/add_silence_to_transcript.py | 112 +++++ egs/timit/asr/simple_v1/local/apply_map.pl | 97 ++++ egs/timit/asr/simple_v1/local/arpa2fst.py | 168 +++++++ .../local/convert_transcript_to_corpus.py | 132 ++++++ egs/timit/asr/simple_v1/local/filter_scp.pl | 87 ++++ .../asr/simple_v1/local/fstaddselfloops.pl | 44 ++ egs/timit/asr/simple_v1/local/make_kn_lm.py | 377 ++++++++++++++++ .../asr/simple_v1/local/make_lexicon_fst.py | 411 +++++++++++++++++ .../asr/simple_v1/local/parse_options.sh | 97 ++++ egs/timit/asr/simple_v1/local/prepare_lang.sh | 382 ++++++++++++++++ egs/timit/asr/simple_v1/local/sym2int.pl | 101 +++++ .../asr/simple_v1/local2/phones.60-48-39.map | 61 +++ .../asr/simple_v1/local2/timit_data_prep.sh | 86 ++++ .../asr/simple_v1/local2/timit_norm_trans.pl | 92 ++++ .../simple_v1/local2/timit_prepare_dict.sh | 62 +++ egs/timit/asr/simple_v1/local2/train_lms.sh | 92 ++++ egs/timit/asr/simple_v1/path.sh | 3 + egs/timit/asr/simple_v1/prepare.py | 155 +++++++ egs/timit/asr/simple_v1/run.sh | 105 +++++ 23 files changed, 3496 insertions(+) create mode 100644 egs/timit/asr/simple_v1/RESULTS.md create mode 100644 egs/timit/asr/simple_v1/ctc_decode.py create mode 100644 egs/timit/asr/simple_v1/ctc_train.py create mode 100644 egs/timit/asr/simple_v1/local/add_lex_disambig.pl create mode 100644 egs/timit/asr/simple_v1/local/add_silence_to_transcript.py create mode 100644 egs/timit/asr/simple_v1/local/apply_map.pl create mode 100644 egs/timit/asr/simple_v1/local/arpa2fst.py create mode 100644 egs/timit/asr/simple_v1/local/convert_transcript_to_corpus.py create mode 100644 egs/timit/asr/simple_v1/local/filter_scp.pl create mode 100644 egs/timit/asr/simple_v1/local/fstaddselfloops.pl create mode 100644 egs/timit/asr/simple_v1/local/make_kn_lm.py create mode 100644 egs/timit/asr/simple_v1/local/make_lexicon_fst.py create mode 100644 egs/timit/asr/simple_v1/local/parse_options.sh create mode 100644 egs/timit/asr/simple_v1/local/prepare_lang.sh create mode 100644 egs/timit/asr/simple_v1/local/sym2int.pl create mode 100644 egs/timit/asr/simple_v1/local2/phones.60-48-39.map create mode 100644 egs/timit/asr/simple_v1/local2/timit_data_prep.sh create mode 100644 egs/timit/asr/simple_v1/local2/timit_norm_trans.pl create mode 100644 egs/timit/asr/simple_v1/local2/timit_prepare_dict.sh create mode 100644 egs/timit/asr/simple_v1/local2/train_lms.sh create mode 100644 egs/timit/asr/simple_v1/path.sh create mode 100644 egs/timit/asr/simple_v1/prepare.py create mode 100644 egs/timit/asr/simple_v1/run.sh diff --git a/egs/timit/asr/simple_v1/RESULTS.md b/egs/timit/asr/simple_v1/RESULTS.md new file mode 100644 index 00000000..95694bc9 --- /dev/null +++ b/egs/timit/asr/simple_v1/RESULTS.md @@ -0,0 +1,20 @@ +# TIMIT CTC Training Results + +## 2021-09-03 +(Mingshuang Luo): + +### TIMIT CTC_Train +Testing results based on different training epochs: +``` +epoch=20 +2021-09-03 10:54:10,903 INFO [ctc_decode.py:188] %PER 30.34% [2225 / 7333, 293 ins, 441 del, 1491 sub ] + +epoch=30 +2021-09-03 10:59:10,147 INFO [ctc_decode.py:188] %PER 29.77% [2183 / 7333, 221 ins, 473 del, 1489 sub ] + +epoch=35 +2021-09-03 11:11:00,885 INFO [ctc_decode.py:188] %PER 28.94% [2122 / 7333, 266 ins, 397 del, 1459 sub ] + +epoch=40 +2021-09-03 11:12:39,029 INFO [ctc_decode.py:188] %PER 29.52% [2165 / 7333, 304 ins, 348 del, 1513 sub ] +``` diff --git a/egs/timit/asr/simple_v1/ctc_decode.py b/egs/timit/asr/simple_v1/ctc_decode.py new file mode 100644 index 00000000..aff8e13e --- /dev/null +++ b/egs/timit/asr/simple_v1/ctc_decode.py @@ -0,0 +1,198 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2020 Xiaomi Corporation (authors: Daniel Povey, Haowen Qiu) +# Apache 2.0 + +import k2 +import logging +import os +import torch +from k2 import Fsa, SymbolTable +from kaldialign import edit_distance +from pathlib import Path +from typing import Union + +from lhotse import CutSet +from lhotse.dataset import K2SpeechRecognitionDataset +from lhotse.dataset import SingleCutSampler +from snowfall.common import find_first_disambig_symbol +from snowfall.common import get_phone_symbols +from snowfall.common import get_texts +from snowfall.common import load_checkpoint +from snowfall.common import setup_logger +from snowfall.decoding.graph import compile_HLG +from snowfall.models import AcousticModel +from snowfall.models.tdnn_lstm import TdnnLstm1b +from snowfall.training.ctc_graph import build_ctc_topo + +import sys +import argparse + +def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel, + device: Union[str, torch.device], HLG: Fsa, symbols: SymbolTable): + tot_num_cuts = len(dataloader) + num_cuts = 0 + results = [] # a list of pair (ref_words, hyp_words) + for batch_idx, batch in enumerate(dataloader): + + feature = batch['inputs'] + supervisions = batch['supervisions'] + supervision_segments = torch.stack( + (supervisions['sequence_idx'], + torch.floor_divide(supervisions['start_frame'], + model.subsampling_factor), + torch.floor_divide(supervisions['num_frames'], + model.subsampling_factor)), 1).to(torch.int32) + indices = torch.argsort(supervision_segments[:, 2], descending=True) + supervision_segments = supervision_segments[indices] + texts = supervisions['text'] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is [N, T, C] + feature = feature.permute(0, 2, 1) # now feature is [N, C, T] + with torch.no_grad(): + nnet_output = model(feature) + # nnet_output is [N, C, T] + nnet_output = nnet_output.permute(0, 2, + 1) # now nnet_output is [N, T, C] + + blank_bias = -3.0 + nnet_output[:, :, 0] += blank_bias + + dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments) + # assert HLG.is_cuda() + assert HLG.device == nnet_output.device, \ + f"Check failed: HLG.device ({HLG.device}) == nnet_output.device ({nnet_output.device})" + # TODO(haowen): with a small `beam`, we may get empty `target_graph`, + # thus `tot_scores` will be `inf`. Definitely we need to handle this later. + lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0, 7.0, 30, 10000) + + # lattices = k2.intersect_dense(HLG, dense_fsa_vec, 10.0) + best_paths = k2.shortest_path(lattices, use_double_scores=True) + assert best_paths.shape[0] == len(texts) + hyps = get_texts(best_paths, indices) + assert len(hyps) == len(texts) + + for i in range(len(texts)): + hyp_words = [symbols.get(x) for x in hyps[i]] + ref_words = texts[i].split(' ') + results.append((ref_words, hyp_words)) + + if batch_idx % 10 == 0: + logging.info( + 'batch {}, cuts processed until now is {}/{} ({:.6f}%)'.format( + batch_idx, num_cuts, tot_num_cuts, + float(num_cuts) / tot_num_cuts * 100)) + + num_cuts += len(texts) + + return results + + +def main(): + import argparse + + parser = argparse.ArgumentParser(description='Process some integers.') + parser.add_argument('--epoch', type=int, default=9, + help='the checkpoint for loading.') + + parser.add_argument('--mode', type=str, default='TEST', + help='the mode to test.') + + args = parser.parse_args() + + exp_dir = Path('exp-lstm-adam-ctc-musan') + setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug') + + # load L, G, symbol_table + lang_dir = Path('data/lang_nosp') + symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt') + phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt') + phone_ids = get_phone_symbols(phone_symbol_table) + phone_ids_with_blank = [0] + phone_ids + ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank)) + + if not os.path.exists(lang_dir / 'HLG.pt'): + print("Loading L_disambig.fst.txt") + with open(lang_dir / 'L_disambig.fst.txt') as f: + L = k2.Fsa.from_openfst(f.read(), acceptor=False) + print("Loading G.fst.txt") + with open(lang_dir / 'G.fst.txt') as f: + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + first_phone_disambig_id = find_first_disambig_symbol(phone_symbol_table) + first_word_disambig_id = find_first_disambig_symbol(symbol_table) + HLG = compile_HLG(L=L, + G=G, + H=ctc_topo, + labels_disambig_id_start=first_phone_disambig_id, + aux_labels_disambig_id_start=first_word_disambig_id) + torch.save(HLG.as_dict(), lang_dir / 'HLG.pt') + else: + print("Loading pre-compiled HLG") + d = torch.load(lang_dir / 'HLG.pt') + HLG = k2.Fsa.from_dict(d) + + # load dataset + feature_dir = Path('exp/data') + print("About to get test cuts") + cuts_test = CutSet.from_json(feature_dir / 'cuts_{}.json.gz'.format(args.mode)) + + print("About to create test dataset") + test = K2SpeechRecognitionDataset(cuts_test) + sampler = SingleCutSampler(cuts_test, max_frames=100000) + print("About to create test dataloader") + test_dl = torch.utils.data.DataLoader(test, batch_size=None, sampler=sampler, num_workers=1) + + # if not torch.cuda.is_available(): + # logging.error('No GPU detected!') + # sys.exit(-1) + + print("About to load model") + # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N + # device = torch.device('cuda', 1) + device = torch.device('cuda') + model = TdnnLstm1b( + num_features=80, + num_classes=len(phone_ids) + 1, # +1 for the blank symbol + subsampling_factor=2) + + checkpoint = os.path.join(exp_dir, 'epoch-{}.pt'.format(args.epoch)) + load_checkpoint(checkpoint, model) + model.to(device) + model.eval() + + print("convert HLG to device") + HLG = HLG.to(device) + HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0) + HLG.requires_grad_(False) + print("About to decode") + results = decode(dataloader=test_dl, + model=model, + device=device, + HLG=HLG, + symbols=symbol_table) + s = '' + for ref, hyp in results: + s += f'ref={ref}\n' + s += f'hyp={hyp}\n' + # logging.info(s) + # compute PER + dists = [edit_distance(r, h) for r, h in results] + errors = { + key: sum(dist[key] for dist in dists) + for key in ['sub', 'ins', 'del', 'total'] + } + total_words = sum(len(ref) for ref, _ in results) + + logging.info( + f'%PER {errors["total"] / total_words:.2%} ' + f'[{errors["total"]} / {total_words}, {errors["ins"]} ins, {errors["del"]} del, {errors["sub"]} sub ]' + ) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == '__main__': + main() diff --git a/egs/timit/asr/simple_v1/ctc_train.py b/egs/timit/asr/simple_v1/ctc_train.py new file mode 100644 index 00000000..ccf47b95 --- /dev/null +++ b/egs/timit/asr/simple_v1/ctc_train.py @@ -0,0 +1,418 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2020 Xiaomi Corporation (authors: Daniel Povey, Haowen Qiu) +# Apache 2.0 + +import k2 +import logging +import math +import numpy as np +import os +import sys +import torch +import torch.optim as optim +from datetime import datetime +from pathlib import Path +from torch.nn.utils import clip_grad_value_ +from torch.utils.tensorboard import SummaryWriter +from typing import Dict, Optional, Tuple + +from lhotse import CutSet +from lhotse.dataset import CutConcatenate, CutMix, K2SpeechRecognitionDataset, SingleCutSampler +from lhotse.utils import fix_random_seed +from snowfall.common import describe +from snowfall.common import get_phone_symbols +from snowfall.common import load_checkpoint, save_checkpoint +from snowfall.common import save_training_info +from snowfall.common import setup_logger +from snowfall.models import AcousticModel +from snowfall.models.tdnn_lstm import TdnnLstm1b +from snowfall.training.ctc_graph import CtcTrainingGraphCompiler + + +def get_tot_objf_and_num_frames(tot_scores: torch.Tensor, + frames_per_seq: torch.Tensor + ) -> Tuple[float, int, int]: + ''' Figures out the total score(log-prob) over all successful supervision segments + (i.e. those for which the total score wasn't -infinity), and the corresponding + number of frames of neural net output + Args: + tot_scores: a Torch tensor of shape (num_segments,) containing total scores + from forward-backward + frames_per_seq: a Torch tensor of shape (num_segments,) containing the number of + frames for each segment + Returns: + Returns a tuple of 3 scalar tensors: (tot_score, ok_frames, all_frames) + where ok_frames is the frames for successful (finite) segments, and + all_frames is the frames for all segments (finite or not). + ''' + mask = torch.ne(tot_scores, -math.inf) + # finite_indexes is a tensor containing successful segment indexes, e.g. + # [ 0 1 3 4 5 ] + finite_indexes = torch.nonzero(mask).squeeze(1) + if False: + bad_indexes = torch.nonzero(~mask).squeeze(1) + if bad_indexes.shape[0] > 0: + print("Bad indexes: ", bad_indexes, ", bad lengths: ", + frames_per_seq[bad_indexes], " vs. max length ", + torch.max(frames_per_seq), ", avg ", + (torch.sum(frames_per_seq) / frames_per_seq.numel())) + # print("finite_indexes = ", finite_indexes, ", tot_scores = ", tot_scores) + ok_frames = frames_per_seq[finite_indexes].sum() + all_frames = frames_per_seq.sum() + return (tot_scores[finite_indexes].sum(), ok_frames, all_frames) + + +def get_objf(batch: Dict, + model: AcousticModel, + device: torch.device, + graph_compiler: CtcTrainingGraphCompiler, + training: bool, + optimizer: Optional[torch.optim.Optimizer] = None): + feature = batch['inputs'] + supervisions = batch['supervisions'] + supervision_segments = torch.stack( + (supervisions['sequence_idx'], + torch.floor_divide(supervisions['start_frame'], + model.subsampling_factor), + torch.floor_divide(supervisions['num_frames'], + model.subsampling_factor)), 1).to(torch.int32) + indices = torch.argsort(supervision_segments[:, 2], descending=True) + supervision_segments = supervision_segments[indices] + + texts = supervisions['text'] + texts = [texts[idx] for idx in indices] + assert feature.ndim == 3 + # print(supervision_segments[:, 1] + supervision_segments[:, 2]) + + feature = feature.to(device) + # at entry, feature is [N, T, C] + feature = feature.permute(0, 2, 1) # now feature is [N, C, T] + if training: + nnet_output = model(feature) + else: + with torch.no_grad(): + nnet_output = model(feature) + + # nnet_output is [N, C, T] + nnet_output = nnet_output.permute(0, 2, 1) # now nnet_output is [N, T, C] + #print(nnet_output.size()) + decoding_graph = graph_compiler.compile(texts).to(device) + + # nnet_output2 = nnet_output.clone() + # blank_bias = -7.0 + # nnet_output2[:,:,0] += blank_bias + + dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments) + assert decoding_graph.is_cuda() + assert decoding_graph.device == device + assert nnet_output.device == device + + target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec, 10.0) + + tot_scores = target_graph.get_tot_scores( + log_semiring=True, + use_double_scores=True) + + (tot_score, tot_frames, + all_frames) = get_tot_objf_and_num_frames(tot_scores, + supervision_segments[:, 2]) + + if training: + optimizer.zero_grad() + (-tot_score).backward() + clip_grad_value_(model.parameters(), 5.0) + optimizer.step() + + ans = -tot_score.detach().cpu().item(), tot_frames.cpu().item( + ), all_frames.cpu().item() + return ans + + +def get_validation_objf(dataloader: torch.utils.data.DataLoader, + model: AcousticModel, device: torch.device, + graph_compiler: CtcTrainingGraphCompiler): + total_objf = 0. + total_frames = 0. # for display only + total_all_frames = 0. # all frames including those seqs that failed. + + model.eval() + + for batch_idx, batch in enumerate(dataloader): + objf, frames, all_frames = get_objf(batch, model, device, + graph_compiler, False) + total_objf += objf + total_frames += frames + total_all_frames += all_frames + + return total_objf, total_frames, total_all_frames + + +def train_one_epoch(dataloader: torch.utils.data.DataLoader, + valid_dataloader: torch.utils.data.DataLoader, + model: AcousticModel, device: torch.device, + graph_compiler: CtcTrainingGraphCompiler, + optimizer: torch.optim.Optimizer, + current_epoch: int, + tb_writer: SummaryWriter, + num_epochs: int, + global_batch_idx_train: int): + total_objf, total_frames, total_all_frames = 0., 0., 0. + valid_average_objf = float('inf') + time_waiting_for_batch = 0 + prev_timestamp = datetime.now() + + model.train() + for batch_idx, batch in enumerate(dataloader): + global_batch_idx_train += 1 + timestamp = datetime.now() + time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds() + curr_batch_objf, curr_batch_frames, curr_batch_all_frames = \ + get_objf(batch, model, device, graph_compiler, True, optimizer) + + total_objf += curr_batch_objf + total_frames += curr_batch_frames + total_all_frames += curr_batch_all_frames + + if batch_idx % 10 == 0: + logging.info( + 'batch {}, epoch {}/{} ' + 'global average objf: {:.6f} over {} ' + 'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) ' + 'avg time waiting for batch {:.3f}s'.format( + batch_idx, current_epoch, num_epochs, + total_objf / total_frames, total_frames, + 100.0 * total_frames / total_all_frames, + curr_batch_objf / (curr_batch_frames + 0.001), + curr_batch_frames, + 100.0 * curr_batch_frames / curr_batch_all_frames, + time_waiting_for_batch / max(1, batch_idx))) + + tb_writer.add_scalar('train/global_average_objf', + total_objf / total_frames, global_batch_idx_train) + + tb_writer.add_scalar('train/current_batch_average_objf', + curr_batch_objf / (curr_batch_frames + 0.001), + global_batch_idx_train) + # if batch_idx >= 10: + # print("Exiting early to get profile info") + # sys.exit(0) + + if batch_idx > 0 and batch_idx % 10 == 0: + total_valid_objf, total_valid_frames, total_valid_all_frames = get_validation_objf( + dataloader=valid_dataloader, + model=model, + device=device, + graph_compiler=graph_compiler) + valid_average_objf = total_valid_objf / total_valid_frames + model.train() + logging.info( + 'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)' + .format(valid_average_objf, + total_valid_frames, + 100.0 * total_valid_frames / total_valid_all_frames)) + + tb_writer.add_scalar('train/global_valid_average_objf', + valid_average_objf, + global_batch_idx_train) + prev_timestamp = datetime.now() + return total_objf / total_frames, valid_average_objf, global_batch_idx_train + + +def main(): + import argparse + + parser = argparse.ArgumentParser(description='Process some integers.') + parser.add_argument('--epochs', type=int, default=40, help='the number of epoch for training.') + + parser.add_argument('--duration', type=int, default=200, help='the max duration in a batch for training.') + + args = parser.parse_args() + + fix_random_seed(42) + + start_epoch = 0 + num_epochs = args.epochs + + exp_dir = 'exp-lstm-adam-ctc-musan' + setup_logger('{}/log/log-train'.format(exp_dir)) + tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard') + + # load L, G, symbol_table + lang_dir = Path('data/lang_nosp') + phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt') + word_symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt') + + logging.info("Loading L.fst") + if (lang_dir / 'Linv.pt').exists(): + L_inv = k2.Fsa.from_dict(torch.load(lang_dir / 'Linv.pt')) + else: + with open(lang_dir / 'L.fst.txt') as f: + L = k2.Fsa.from_openfst(f.read(), acceptor=False) + L_inv = k2.arc_sort(L.invert_()) + torch.save(L_inv.as_dict(), lang_dir / 'Linv.pt') + + graph_compiler = CtcTrainingGraphCompiler( + L_inv=L_inv, + phones=phone_symbol_table, + words=word_symbol_table + ) + phone_ids = get_phone_symbols(phone_symbol_table) + + # load dataset + feature_dir = Path('exp/data') + logging.info("About to get train cuts") + cuts_train = CutSet.from_json(feature_dir / + 'cuts_TRAIN.json.gz') + logging.info("About to get dev cuts") + cuts_dev = CutSet.from_json(feature_dir / 'cuts_TEST.json.gz') + logging.info("About to get Musan cuts") + cuts_musan = CutSet.from_json(feature_dir / 'cuts_musan.json.gz') + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + cuts_train, + cut_transforms=[ + CutConcatenate(), + CutMix( + cuts=cuts_musan, + prob=0.5, + snr=(10, 20) + ) + ] + ) + train_sampler = SingleCutSampler( + cuts_train, + #max_frames=180000, + max_duration=args.duration, + shuffle=True, + ) + logging.info("About to create train dataloader") + train_dl = torch.utils.data.DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=4 + ) + logging.info("About to create dev dataset") + validate = K2SpeechRecognitionDataset(cuts_dev) + valid_sampler = SingleCutSampler(cuts_dev, max_frames=90000) + logging.info("About to create dev dataloader") + valid_dl = torch.utils.data.DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=1 + ) + + if not torch.cuda.is_available(): + logging.error('No GPU detected!') + sys.exit(-1) + + logging.info("About to create model") + device_id = 0 + device = torch.device('cuda', device_id) + model = TdnnLstm1b( + num_features=80, + num_classes=len(phone_ids) + 1, # +1 for the blank symbol + subsampling_factor=2) + + model.to(device) + describe(model) + + learning_rate = 2e-3 + optimizer = optim.AdamW(model.parameters(), + lr=learning_rate, + weight_decay=5e-4) + + best_objf = np.inf + best_valid_objf = np.inf + best_epoch = start_epoch + best_model_path = os.path.join(exp_dir, 'best_model.pt') + best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info') + global_batch_idx_train = 0 # for logging only + + if start_epoch > 0: + model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(start_epoch - 1)) + ckpt = load_checkpoint(filename=model_path, model=model, optimizer=optimizer) + best_objf = ckpt['objf'] + best_valid_objf = ckpt['valid_objf'] + global_batch_idx_train = ckpt['global_batch_idx_train'] + logging.info(f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}") + + for epoch in range(start_epoch, args.epochs): + train_sampler.set_epoch(epoch) + curr_learning_rate = 1e-3 + # curr_learning_rate = learning_rate * pow(0.4, epoch) + # for param_group in optimizer.param_groups: + # param_group['lr'] = curr_learning_rate + + tb_writer.add_scalar('learning_rate', curr_learning_rate, epoch) + + logging.info('epoch {}, learning rate {}'.format( + epoch, curr_learning_rate)) + objf, valid_objf, global_batch_idx_train = train_one_epoch(dataloader=train_dl, + valid_dataloader=valid_dl, + model=model, + device=device, + graph_compiler=graph_compiler, + optimizer=optimizer, + current_epoch=epoch, + tb_writer=tb_writer, + num_epochs=num_epochs, + global_batch_idx_train=global_batch_idx_train) + # the lower, the better + if valid_objf < best_valid_objf: + best_valid_objf = valid_objf + best_objf = objf + best_epoch = epoch + save_checkpoint(filename=best_model_path, + model=model, + epoch=epoch, + optimizer=None, + scheduler=None, + learning_rate=curr_learning_rate, + objf=objf, + valid_objf=valid_objf, + global_batch_idx_train=global_batch_idx_train) + save_training_info(filename=best_epoch_info_filename, + model_path=best_model_path, + current_epoch=epoch, + learning_rate=curr_learning_rate, + objf=best_objf, + best_objf=best_objf, + valid_objf=valid_objf, + best_valid_objf=best_valid_objf, + best_epoch=best_epoch) + + # we always save the model for every epoch + model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch)) + save_checkpoint(filename=model_path, + model=model, + optimizer=optimizer, + scheduler=None, + epoch=epoch, + learning_rate=curr_learning_rate, + objf=objf, + valid_objf=valid_objf, + global_batch_idx_train=global_batch_idx_train) + epoch_info_filename = os.path.join(exp_dir, + 'epoch-{}-info'.format(epoch)) + save_training_info(filename=epoch_info_filename, + model_path=model_path, + current_epoch=epoch, + learning_rate=curr_learning_rate, + objf=objf, + best_objf=best_objf, + valid_objf=valid_objf, + best_valid_objf=best_valid_objf, + best_epoch=best_epoch) + + logging.warning('Done') + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == '__main__': + main() diff --git a/egs/timit/asr/simple_v1/local/add_lex_disambig.pl b/egs/timit/asr/simple_v1/local/add_lex_disambig.pl new file mode 100644 index 00000000..c4277e8d --- /dev/null +++ b/egs/timit/asr/simple_v1/local/add_lex_disambig.pl @@ -0,0 +1,196 @@ +#!/usr/bin/env perl +# Copyright 2010-2011 Microsoft Corporation +# 2013-2016 Johns Hopkins University (author: Daniel Povey) +# 2015 Hainan Xu +# 2015 Guoguo Chen + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# Adds disambiguation symbols to a lexicon. +# Outputs still in the normal lexicon format. +# Disambig syms are numbered #1, #2, #3, etc. (#0 +# reserved for symbol in grammar). +# Outputs the number of disambig syms to the standard output. +# With the --pron-probs option, expects the second field +# of each lexicon line to be a pron-prob. +# With the --sil-probs option, expects three additional +# fields after the pron-prob, representing various components +# of the silence probability model. + +$pron_probs = 0; +$sil_probs = 0; +$first_allowed_disambig = 1; + +for ($n = 1; $n <= 3 && @ARGV > 0; $n++) { + if ($ARGV[0] eq "--pron-probs") { + $pron_probs = 1; + shift @ARGV; + } + if ($ARGV[0] eq "--sil-probs") { + $sil_probs = 1; + shift @ARGV; + } + if ($ARGV[0] eq "--first-allowed-disambig") { + $first_allowed_disambig = 0 + $ARGV[1]; + if ($first_allowed_disambig < 1) { + die "add_lex_disambig.pl: invalid --first-allowed-disambig option: $first_allowed_disambig\n"; + } + shift @ARGV; + shift @ARGV; + } +} + +if (@ARGV != 2) { + die "Usage: add_lex_disambig.pl [opts] \n" . + "This script adds disambiguation symbols to a lexicon in order to\n" . + "make decoding graphs determinizable; it adds pseudo-phone\n" . + "disambiguation symbols #1, #2 and so on at the ends of phones\n" . + "to ensure that all pronunciations are different, and that none\n" . + "is a prefix of another.\n" . + "It prints to the standard output the number of the largest-numbered" . + "disambiguation symbol that was used.\n" . + "\n" . + "Options: --pron-probs Expect pronunciation probabilities in the 2nd field\n" . + " --sil-probs [should be with --pron-probs option]\n" . + " Expect 3 extra fields after the pron-probs, for aspects of\n" . + " the silence probability model\n" . + " --first-allowed-disambig The number of the first disambiguation symbol\n" . + " that this script is allowed to add. By default this is\n" . + " #1, but you can set this to a larger value using this option.\n" . + "e.g.:\n" . + " add_lex_disambig.pl lexicon.txt lexicon_disambig.txt\n" . + " add_lex_disambig.pl --pron-probs lexiconp.txt lexiconp_disambig.txt\n" . + " add_lex_disambig.pl --pron-probs --sil-probs lexiconp_silprob.txt lexiconp_silprob_disambig.txt\n"; +} + + +$lexfn = shift @ARGV; +$lexoutfn = shift @ARGV; + +open(L, "<$lexfn") || die "Error opening lexicon $lexfn"; + +# (1) Read in the lexicon. +@L = ( ); +while() { + @A = split(" ", $_); + push @L, join(" ", @A); +} + +# (2) Work out the count of each phone-sequence in the +# lexicon. + +foreach $l (@L) { + @A = split(" ", $l); + shift @A; # Remove word. + if ($pron_probs) { + $p = shift @A; + if (!($p > 0.0 && $p <= 1.0)) { die "Bad lexicon line $l (expecting pron-prob as second field)"; } + } + if ($sil_probs) { + $silp = shift @A; + if (!($silp > 0.0 && $silp <= 1.0)) { die "Bad lexicon line $l for silprobs"; } + $correction = shift @A; + if ($correction <= 0.0) { die "Bad lexicon line $l for silprobs"; } + $correction = shift @A; + if ($correction <= 0.0) { die "Bad lexicon line $l for silprobs"; } + } + if (!(@A)) { + die "Bad lexicon line $1, no phone in phone list"; + } + $count{join(" ",@A)}++; +} + +# (3) For each left sub-sequence of each phone-sequence, note down +# that it exists (for identifying prefixes of longer strings). + +foreach $l (@L) { + @A = split(" ", $l); + shift @A; # Remove word. + if ($pron_probs) { shift @A; } # remove pron-prob. + if ($sil_probs) { + shift @A; # Remove silprob + shift @A; # Remove silprob + shift @A; # Remove silprob, there three numbers for sil_probs + } + while(@A > 0) { + pop @A; # Remove last phone + $issubseq{join(" ",@A)} = 1; + } +} + +# (4) For each entry in the lexicon: +# if the phone sequence is unique and is not a +# prefix of another word, no diambig symbol. +# Else output #1, or #2, #3, ... if the same phone-seq +# has already been assigned a disambig symbol. + + +open(O, ">$lexoutfn") || die "Opening lexicon file $lexoutfn for writing.\n"; + +# max_disambig will always be the highest-numbered disambiguation symbol that +# has been used so far. +$max_disambig = $first_allowed_disambig - 1; + +foreach $l (@L) { + @A = split(" ", $l); + $word = shift @A; + if ($pron_probs) { + $pron_prob = shift @A; + } + if ($sil_probs) { + $sil_word_prob = shift @A; + $word_sil_correction = shift @A; + $prev_nonsil_correction = shift @A + } + $phnseq = join(" ", @A); + if (!defined $issubseq{$phnseq} + && $count{$phnseq} == 1) { + ; # Do nothing. + } else { + if ($phnseq eq "") { # need disambig symbols for the empty string + # that are not use anywhere else. + $max_disambig++; + $reserved_for_the_empty_string{$max_disambig} = 1; + $phnseq = "#$max_disambig"; + } else { + $cur_disambig = $last_used_disambig_symbol_of{$phnseq}; + if (!defined $cur_disambig) { + $cur_disambig = $first_allowed_disambig; + } else { + $cur_disambig++; # Get a number that has not been used yet for + # this phone sequence. + } + while (defined $reserved_for_the_empty_string{$cur_disambig}) { + $cur_disambig++; + } + if ($cur_disambig > $max_disambig) { + $max_disambig = $cur_disambig; + } + $last_used_disambig_symbol_of{$phnseq} = $cur_disambig; + $phnseq = $phnseq . " #" . $cur_disambig; + } + } + if ($pron_probs) { + if ($sil_probs) { + print O "$word\t$pron_prob\t$sil_word_prob\t$word_sil_correction\t$prev_nonsil_correction\t$phnseq\n"; + } else { + print O "$word\t$pron_prob\t$phnseq\n"; + } + } else { + print O "$word\t$phnseq\n"; + } +} + +print $max_disambig . "\n"; diff --git a/egs/timit/asr/simple_v1/local/add_silence_to_transcript.py b/egs/timit/asr/simple_v1/local/add_silence_to_transcript.py new file mode 100644 index 00000000..685bb2f2 --- /dev/null +++ b/egs/timit/asr/simple_v1/local/add_silence_to_transcript.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 + +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +''' +Add silence with a given probability after each word in the transcript. + +If the input transcript contains: + + hello world + foo bar koo + zoo + +Then the output transcript **may** look like the following: + + !SIL hello !SIL world !SIL + foo bar !SIL koo !SIL + !SIL zoo !SIL + +(Assume !SIL represents silence.) +''' + +from pathlib import Path + +import argparse +import random + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--transcript', + type=str, + help='The input transcript file.' + 'We assume that the transcript file consists of ' + 'lines. Each line consists of space separated words.') + parser.add_argument('--sil-word', + type=str, + default='!SIL', + help='The word that represents silence.') + parser.add_argument('--sil-prob', + type=float, + default=0.5, + help='The probability for adding a ' + 'silence after each world.') + parser.add_argument('--seed', + type=int, + default=None, + help='The seed for random number generators.') + + return parser.parse_args() + + +def need_silence(sil_prob: float) -> bool: + ''' + Args: + sil_prob: + The probability to add a silence. + Returns: + Return True if a silence is needed. + Return False otherwise. + ''' + return random.uniform(0, 1) <= sil_prob + + +def process_line(line: str, sil_word: str, sil_prob: float) -> None: + '''Process a single line from the transcript. + + Args: + line: + A str containing space separated words. + sil_word: + The symbol indicating silence. + sil_prob: + The probability for adding a silence after each word. + Returns: + Return None. + ''' + + words = line.strip().split(' ')[1:] + + for i, word in enumerate(words): + if i == 0: + # beginning of the line + if need_silence(sil_prob): + print(sil_word, end=' ') + + print(word, end=' ') + + if need_silence(sil_prob): + print(sil_word, end=' ') + + # end of the line, print a new line + if i == len(words) - 1: + print() + + +def main(): + args = get_args() + random.seed(args.seed) + + assert Path(args.transcript).is_file() + assert len(args.sil_word) > 0 + assert 0 < args.sil_prob < 1 + + with open(args.transcript) as f: + for line in f: + process_line(line=line, + sil_word=args.sil_word, + sil_prob=args.sil_prob) + + +if __name__ == '__main__': + main() diff --git a/egs/timit/asr/simple_v1/local/apply_map.pl b/egs/timit/asr/simple_v1/local/apply_map.pl new file mode 100644 index 00000000..725d3463 --- /dev/null +++ b/egs/timit/asr/simple_v1/local/apply_map.pl @@ -0,0 +1,97 @@ +#!/usr/bin/env perl +use warnings; #sed replacement for -w perl parameter +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey) +# Apache 2.0. + +# This program is a bit like ./sym2int.pl in that it applies a map +# to things in a file, but it's a bit more general in that it doesn't +# assume the things being mapped to are single tokens, they could +# be sequences of tokens. See the usage message. + + +$permissive = 0; + +for ($x = 0; $x <= 2; $x++) { + + if (@ARGV > 0 && $ARGV[0] eq "-f") { + shift @ARGV; + $field_spec = shift @ARGV; + if ($field_spec =~ m/^\d+$/) { + $field_begin = $field_spec - 1; $field_end = $field_spec - 1; + } + if ($field_spec =~ m/^(\d*)[-:](\d*)/) { # accept e.g. 1:10 as a courtesty (properly, 1-10) + if ($1 ne "") { + $field_begin = $1 - 1; # Change to zero-based indexing. + } + if ($2 ne "") { + $field_end = $2 - 1; # Change to zero-based indexing. + } + } + if (!defined $field_begin && !defined $field_end) { + die "Bad argument to -f option: $field_spec"; + } + } + + if (@ARGV > 0 && $ARGV[0] eq '--permissive') { + shift @ARGV; + # Mapping is optional (missing key is printed to output) + $permissive = 1; + } +} + +if(@ARGV != 1) { + print STDERR "Invalid usage: " . join(" ", @ARGV) . "\n"; + print STDERR <<'EOF'; +Usage: apply_map.pl [options] map output + options: [-f ] [--permissive] + This applies a map to some specified fields of some input text: + For each line in the map file: the first field is the thing we + map from, and the remaining fields are the sequence we map it to. + The -f (field-range) option says which fields of the input file the map + map should apply to. + If the --permissive option is supplied, fields which are not present + in the map will be left as they were. + Applies the map 'map' to all input text, where each line of the map + is interpreted as a map from the first field to the list of the other fields + Note: can look like 4-5, or 4-, or 5-, or 1, it means the field + range in the input to apply the map to. + e.g.: echo A B | apply_map.pl a.txt + where a.txt is: + A a1 a2 + B b + will produce: + a1 a2 b +EOF + exit(1); +} + +($map_file) = @ARGV; +open(M, "<$map_file") || die "Error opening map file $map_file: $!"; + +while () { + @A = split(" ", $_); + @A >= 1 || die "apply_map.pl: empty line."; + $i = shift @A; + $o = join(" ", @A); + $map{$i} = $o; +} + +while() { + @A = split(" ", $_); + for ($x = 0; $x < @A; $x++) { + if ( (!defined $field_begin || $x >= $field_begin) + && (!defined $field_end || $x <= $field_end)) { + $a = $A[$x]; + if (!defined $map{$a}) { + if (!$permissive) { + die "apply_map.pl: undefined key $a in $map_file\n"; + } else { + print STDERR "apply_map.pl: warning! missing key $a in $map_file\n"; + } + } else { + $A[$x] = $map{$a}; + } + } + } + print join(" ", @A) . "\n"; +} diff --git a/egs/timit/asr/simple_v1/local/arpa2fst.py b/egs/timit/asr/simple_v1/local/arpa2fst.py new file mode 100644 index 00000000..61a4167d --- /dev/null +++ b/egs/timit/asr/simple_v1/local/arpa2fst.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python3 + +# Copyright 2020 Xiaomi Corporation (Author: Junbo Zhang, Haowen Qiu) +# Apache 2.0 +""" +arpa2fst.py: +Similar to Kaldi's `lmbin/arpa2fst.cc`, rewritten in Python. + +Generally, we do the following. Suppose we are adding an n-gram "A B C". Then find the node for "A B", add a new node +for "A B C", and connect them with the arc accepting "C" with the specified weight. Also, add a backoff arc from the +new "A B C" node to its backoff state "B C". + +Two notable exceptions are the highest order n-grams, and final n-grams. + +When adding a highest order n-gram (e. g., our "A B C" is in a 3-gram LM), the following optimization is performed. +There is no point adding a node for "A B C" with a "C" arc from "A B", since there will be no other arcs ingoing to +this node, and an epsilon backoff arc into the backoff model "B C", with the weight of \bar{1}. To save a node, +create an arc accepting "C" directly from "A B" to "B C". This saves as many nodes as there are the highest order +n-grams, which is typically about half the size of a large 3-gram model. + +Indeed, this does not apply to n-grams ending in EOS, since they do not back off. These are special, as they do not +have a back-off state, and the node for "(..anything..) " is always final. These are handled in one of the two +possible ways, If symbols and are being replaced by epsilons, neither node nor arc is created, +and the logprob of the n-gram is applied to its source node as final weight. If and are preserved, +then a special final node for is allocated and used as the destination of the "" acceptor arc. """ + +import re +import argparse +from typing import List, NamedTuple + + +def get_args(): + parser = argparse.ArgumentParser( + description="""This script is similar to Kaldi's `lmbin/arpa2fst.cc`, + but was rewritten in Python. The output goes to the stdout.""") + parser.add_argument('arpa', type=str, help="""The arpa file.""") + parser.add_argument('--bos', + type=str, + default='', + help="""The begin symbol.""") + parser.add_argument('--eos', + type=str, + default='', + help="""The ending symbol.""") + parser.add_argument('--disambig_symbol', + type=str, + default='#0', + help="""The disambig symbol.""") + parser.add_argument('--to-natural-log', + type=bool, + default=True, + help="""Convert to natural log.""") + args = parser.parse_args() + return args + + +class Ngram(NamedTuple): + logprob: float + prev_words: List + cur_word: str + backoff: float + + +def parse_arpa_file(filename, to_natural_log=True): + ngrams = [] + with open(filename) as f: + stage = 0 # stage 0 for header, 1 for unigram, 2 for bigram, ... + for line in f: + line = re.sub(r'\s+$', '', line) + if re.match(r'^\s*$', line) or re.match( + r'\\(data|end)\\', line) or re.match(r'ngram \d+=', line): + continue + elif re.match(r'\\\d+-grams:', line): + stage = int(re.match(r'\\(\d+)-grams:', line).group(1)) + ngrams.append([]) + assert len(ngrams) == stage + else: + items = re.split(r'\s+', line) + assert stage + 1 <= len( + items) <= stage + 2, f'Invalid arpa line: {line}' + + logprob = float(items[0]) if not to_natural_log else float( + items[0]) * 2.30258509299404568402 + words = items[1:stage + 1] + backoff = float(items[stage + + 1]) if len(items) == stage + 2 else 0.0 + if to_natural_log: + backoff = backoff * 2.30258509299404568402 + + if stage == 1: + prev_words = [] + assert len(words) == 1 + cur_word = words[0] + else: + prev_words = words[0:-1] + cur_word = words[-1] + + ngrams[stage - 1].append( + Ngram(logprob=logprob, + prev_words=prev_words, + cur_word=cur_word, + backoff=backoff)) + + return ngrams + + +def create_backoff(key, state, state_id, weight, sub_eps): + while key not in state_id: + key = ' '.join((key.split(' '))[1:]) + dest = state_id[key] + print(f'{state} {dest} {sub_eps} {weight}') + + +# TODO(haowen): refactor the code (and support case where sub_eps is empty?) +def print_fst_from_ngrams(ngram_lm, bos='', eos='', sub_eps='#0'): + # for now this version only support case that disambig sybmol is not zero + assert sub_eps != '' + highest_order = len(ngram_lm) + state_id = {bos: 0, '': 1, eos: 2} + state_count = len(state_id) + for order, ngrams in enumerate(ngram_lm, start=1): + for logprob, prev_words, cur_word, backoff in ngrams: + assert len(prev_words) + 1 == order + prev_words_str = ' '.join(prev_words) + whole_words_str = ' '.join(prev_words + [cur_word]) + if prev_words_str not in state_id: + continue # no parent (n-1) gram + source = state_id[prev_words_str] + weight = -logprob + assert cur_word != sub_eps + if cur_word == eos: + if sub_eps == '': + dest = state_id[eos] + else: + # treat as if it was epsilon; mark source final + print(f'{source} {weight}') + continue + else: + key = whole_words_str if order != highest_order else ' '.join( + (prev_words + [cur_word])[1:]) + if key not in state_id: + state_id[key] = state_count + dest = state_count + state_count += 1 + tails = ' '.join((key.split(' '))[1:]) + create_backoff(tails, dest, state_id, -backoff, sub_eps) + else: + dest = state_id[key] + + if cur_word == bos: + weight = 0 + if sub_eps != '': + continue + + print(f'{source} {dest} {cur_word} {weight}') + + if sub_eps == '': + print(f'{state_id[eos]} 0') + + +def main(): + args = get_args() + print_fst_from_ngrams(parse_arpa_file(args.arpa, args.to_natural_log), + args.bos, args.eos, args.disambig_symbol) + + +if __name__ == '__main__': + main() diff --git a/egs/timit/asr/simple_v1/local/convert_transcript_to_corpus.py b/egs/timit/asr/simple_v1/local/convert_transcript_to_corpus.py new file mode 100644 index 00000000..62c60074 --- /dev/null +++ b/egs/timit/asr/simple_v1/local/convert_transcript_to_corpus.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python3 + +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +''' +Convert a transcript file to a corpus for LM training with +the help of a lexicon. If the lexicon contains phones, the resulting +LM will be a phone LM; If the lexicon contains word pieces, +the resulting LM will be a word piece LM. + +If a word has multiple pronunciations, only the first one is used. + +If the input transcript is: + + hello zoo world hello + world zoo + foo zoo world hellO + +and if the lexicon is + + SPN + hello h e l l o 2 + hello h e l l o + world w o r l d + zoo z o o + +Then the output is + + h e l l o 2 z o o w o r l d h e l l o 2 + w o r l d z o o + SPN z o o w o r l d SPN +''' + +from pathlib import Path +from typing import Dict + +import argparse + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--transcript', + type=str, + help='The input transcript file.' + 'We assume that the transcript file consists of ' + 'lines. Each line consists of space separated words.') + parser.add_argument('--lexicon', type=str, help='The input lexicon file.') + parser.add_argument('--oov', + type=str, + default='', + help='The OOV word.') + + return parser.parse_args() + + +def read_lexicon(filename: str) -> Dict[str, str]: + ''' + Args: + filename: + Filename to the lexicon. Each line in the lexicon + has the following format: + + word p1 p2 p3 + + where the first field is a word and the remaining fields + are the pronunciations of the word. Fields are separated + by spaces. + Returns: + Return a dict whose keys are words and values are the pronunciations. + ''' + ans = dict() + with open(filename) as f: + for line in f: + line = line.strip() + + if len(line) == 0: + # skip empty lines + continue + + fields = line.split() + assert len(fields) >= 2 + + word = fields[0] + pron = ' '.join(fields[1:]) + + if word not in ans: + # In case a word has multiple pronunciations, + # we only use the first one + ans[word] = pron + return ans + + +def process_line(lexicon: Dict[str, str], line: str, oov_pron: str) -> None: + ''' + Args: + lexicon: + A dict containing pronunciations. Its keys are words and values + are pronunciations. + line: + A line of transcript consisting of space separated words. + oov_pron: + The pronunciation of the oov word if a word in line is not present + in the lexicon. + Returns: + Return None. + ''' + words = line.strip().split() + for i, w in enumerate(words): + pron = lexicon.get(w, oov_pron) + print(pron, end=' ') + if i == len(words) - 1: + # end of the line, prints a new line + print() + + +def main(): + args = get_args() + assert Path(args.lexicon).is_file() + assert Path(args.transcript).is_file() + assert len(args.oov) > 0 + + lexicon = read_lexicon(args.lexicon) + assert args.oov in lexicon + + oov_pron = lexicon[args.oov] + + with open(args.transcript) as f: + for line in f: + process_line(lexicon=lexicon, line=line, oov_pron=oov_pron) + + +if __name__ == '__main__': + main() diff --git a/egs/timit/asr/simple_v1/local/filter_scp.pl b/egs/timit/asr/simple_v1/local/filter_scp.pl new file mode 100644 index 00000000..b76d37f4 --- /dev/null +++ b/egs/timit/asr/simple_v1/local/filter_scp.pl @@ -0,0 +1,87 @@ +#!/usr/bin/env perl +# Copyright 2010-2012 Microsoft Corporation +# Johns Hopkins University (author: Daniel Povey) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# This script takes a list of utterance-ids or any file whose first field +# of each line is an utterance-id, and filters an scp +# file (or any file whose "n-th" field is an utterance id), printing +# out only those lines whose "n-th" field is in id_list. The index of +# the "n-th" field is 1, by default, but can be changed by using +# the -f switch + +$exclude = 0; +$field = 1; +$shifted = 0; + +do { + $shifted=0; + if ($ARGV[0] eq "--exclude") { + $exclude = 1; + shift @ARGV; + $shifted=1; + } + if ($ARGV[0] eq "-f") { + $field = $ARGV[1]; + shift @ARGV; shift @ARGV; + $shifted=1 + } +} while ($shifted); + +if(@ARGV < 1 || @ARGV > 2) { + die "Usage: filter_scp.pl [--exclude] [-f ] id_list [in.scp] > out.scp \n" . + "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" . + "Note: only the first field of each line in id_list matters. With --exclude, prints\n" . + "only the lines that were *not* in id_list.\n" . + "Caution: previously, the -f option was interpreted as a zero-based field index.\n" . + "If your older scripts (written before Oct 2014) stopped working and you used the\n" . + "-f option, add 1 to the argument.\n" . + "See also: utils/filter_scp.pl .\n"; +} + + +$idlist = shift @ARGV; +open(F, "<$idlist") || die "Could not open id-list file $idlist"; +while() { + @A = split; + @A>=1 || die "Invalid id-list file line $_"; + $seen{$A[0]} = 1; +} + +if ($field == 1) { # Treat this as special case, since it is common. + while(<>) { + $_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field."; + # $1 is what we filter on. + if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) { + print $_; + } + } +} else { + while(<>) { + @A = split; + @A > 0 || die "Invalid scp file line $_"; + @A >= $field || die "Invalid scp file line $_"; + if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) { + print $_; + } + } +} + +# tests: +# the following should print "foo 1" +# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl <(echo foo) +# the following should print "bar 2". +# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl -f 2 <(echo 2) diff --git a/egs/timit/asr/simple_v1/local/fstaddselfloops.pl b/egs/timit/asr/simple_v1/local/fstaddselfloops.pl new file mode 100644 index 00000000..133a518e --- /dev/null +++ b/egs/timit/asr/simple_v1/local/fstaddselfloops.pl @@ -0,0 +1,44 @@ +#!/usr/bin/env perl + +# Copyright 2020 Xiaomi Corporation (Author: Junbo Zhang) +# Apache 2.0 + +use strict; +use warnings; + +my $Usage = < < + e.g.: cat L_disambig.txt | local/fstaddselfloops.pl 347 200004 > L_disambig_with_loop.txt +EOU + +if (@ARGV != 2) { + die $Usage; +} + +my $wdisambig_phone = shift @ARGV; +my $wdisambig_word = shift @ARGV; + +my %states_needs_self_loops; +while (<>) { + print $_; + + my @items = split(/\s+/); + if (@items == 2) { + # it is a final state + $states_needs_self_loops{$items[0]} = 1; + } elsif (@items == 5) { + my ($src, $dst, $inlabel, $outlabel, $score) = @items; + $states_needs_self_loops{$src} = 1 if ($outlabel != 0); + } else { + die "Invalid openfst line."; + } +} + +foreach (keys %states_needs_self_loops) { + print "$_ $_ $wdisambig_phone $wdisambig_word 0.0\n" +} diff --git a/egs/timit/asr/simple_v1/local/make_kn_lm.py b/egs/timit/asr/simple_v1/local/make_kn_lm.py new file mode 100644 index 00000000..58b721d2 --- /dev/null +++ b/egs/timit/asr/simple_v1/local/make_kn_lm.py @@ -0,0 +1,377 @@ +#!/usr/bin/env python3 + +# Copyright 2016 Johns Hopkins University (Author: Daniel Povey) +# 2018 Ruizhe Huang +# Apache 2.0. + +# This is an implementation of computing Kneser-Ney smoothed language model +# in the same way as srilm. This is a back-off, unmodified version of +# Kneser-Ney smoothing, which produces the same results as the following +# command (as an example) of srilm: +# +# $ ngram-count -order 4 -kn-modify-counts-at-end -ukndiscount -gt1min 0 -gt2min 0 -gt3min 0 -gt4min 0 \ +# -text corpus.txt -lm lm.arpa +# +# The data structure is based on: kaldi/egs/wsj/s5/utils/lang/make_phone_lm.py +# The smoothing algorithm is based on: http://www.speech.sri.com/projects/srilm/manpages/ngram-discount.7.html + +import sys +import os +import re +import io +import math +import argparse +from collections import Counter, defaultdict + + +parser = argparse.ArgumentParser(description=""" + Generate kneser-ney language model as arpa format. By default, + it will read the corpus from standard input, and output to standard output. + """) +parser.add_argument("-ngram-order", type=int, default=4, choices=[2, 3, 4, 5, 6, 7], help="Order of n-gram") +parser.add_argument("-text", type=str, default=None, help="Path to the corpus file") +parser.add_argument("-lm", type=str, default=None, help="Path to output arpa file for language models") +parser.add_argument("-verbose", type=int, default=0, choices=[0, 1, 2, 3, 4, 5], help="Verbose level") +args = parser.parse_args() + +default_encoding = "latin-1" # For encoding-agnostic scripts, we assume byte stream as input. + # Need to be very careful about the use of strip() and split() + # in this case, because there is a latin-1 whitespace character + # (nbsp) which is part of the unicode encoding range. + # Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717 +strip_chars = " \t\r\n" +whitespace = re.compile("[ \t]+") + + +class CountsForHistory: + # This class (which is more like a struct) stores the counts seen in a + # particular history-state. It is used inside class NgramCounts. + # It really does the job of a dict from int to float, but it also + # keeps track of the total count. + def __init__(self): + # The 'lambda: defaultdict(float)' is an anonymous function taking no + # arguments that returns a new defaultdict(float). + self.word_to_count = defaultdict(int) + self.word_to_context = defaultdict(set) # using a set to count the number of unique contexts + self.word_to_f = dict() # discounted probability + self.word_to_bow = dict() # back-off weight + self.total_count = 0 + + def words(self): + return self.word_to_count.keys() + + def __str__(self): + # e.g. returns ' total=12: 3->4, 4->6, -1->2' + return ' total={0}: {1}'.format( + str(self.total_count), + ', '.join(['{0} -> {1}'.format(word, count) + for word, count in self.word_to_count.items()])) + + def add_count(self, predicted_word, context_word, count): + assert count >= 0 + + self.total_count += count + self.word_to_count[predicted_word] += count + if context_word is not None: + self.word_to_context[predicted_word].add(context_word) + + +class NgramCounts: + # A note on data-structure. Firstly, all words are represented as + # integers. We store n-gram counts as an array, indexed by (history-length + # == n-gram order minus one) (note: python calls arrays "lists") of dicts + # from histories to counts, where histories are arrays of integers and + # "counts" are dicts from integer to float. For instance, when + # accumulating the 4-gram count for the '8' in the sequence '5 6 7 8', we'd + # do as follows: self.counts[3][[5,6,7]][8] += 1.0 where the [3] indexes an + # array, the [[5,6,7]] indexes a dict, and the [8] indexes a dict. + def __init__(self, ngram_order, bos_symbol='', eos_symbol=''): + assert ngram_order >= 2 + + self.ngram_order = ngram_order + self.bos_symbol = bos_symbol + self.eos_symbol = eos_symbol + + self.counts = [] + for n in range(ngram_order): + self.counts.append(defaultdict(lambda: CountsForHistory())) + + self.d = [] # list of discounting factor for each order of ngram + + # adds a raw count (called while processing input data). + # Suppose we see the sequence '6 7 8 9' and ngram_order=4, 'history' + # would be (6,7,8) and 'predicted_word' would be 9; 'count' would be + # 1. + def add_count(self, history, predicted_word, context_word, count): + self.counts[len(history)][history].add_count(predicted_word, context_word, count) + + # 'line' is a string containing a sequence of integer word-ids. + # This function adds the un-smoothed counts from this line of text. + def add_raw_counts_from_line(self, line): + if line == '': + words = [self.bos_symbol, self.eos_symbol] + else: + words = [self.bos_symbol] + whitespace.split(line) + [self.eos_symbol] + + for i in range(len(words)): + for n in range(1, self.ngram_order+1): + if i + n > len(words): + break + ngram = words[i: i + n] + predicted_word = ngram[-1] + history = tuple(ngram[: -1]) + if i == 0 or n == self.ngram_order: + context_word = None + else: + context_word = words[i-1] + + self.add_count(history, predicted_word, context_word, 1) + + def add_raw_counts_from_standard_input(self): + lines_processed = 0 + infile = io.TextIOWrapper(sys.stdin.buffer, encoding=default_encoding) # byte stream as input + for line in infile: + line = line.strip(strip_chars) + self.add_raw_counts_from_line(line) + lines_processed += 1 + if lines_processed == 0 or args.verbose > 0: + print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr) + + def add_raw_counts_from_file(self, filename): + lines_processed = 0 + with open(filename, encoding=default_encoding) as fp: + for line in fp: + line = line.strip(strip_chars) + self.add_raw_counts_from_line(line) + lines_processed += 1 + if lines_processed == 0 or args.verbose > 0: + print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr) + + def cal_discounting_constants(self): + # For each order N of N-grams, we calculate discounting constant D_N = n1_N / (n1_N + 2 * n2_N), + # where n1_N is the number of unique N-grams with count = 1 (counts-of-counts). + # This constant is used similarly to absolute discounting. + # Return value: d is a list of floats, where d[N+1] = D_N + + self.d = [0] # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0 + # This is a special case: as we currently assumed having seen all vocabularies in the dictionary, + # but perhaps this is not the case for some other scenarios. + for n in range(1, self.ngram_order): + this_order_counts = self.counts[n] + n1 = 0 + n2 = 0 + for hist, counts_for_hist in this_order_counts.items(): + stat = Counter(counts_for_hist.word_to_count.values()) + n1 += stat[1] + n2 += stat[2] + assert n1 + 2 * n2 > 0 + self.d.append(n1 * 1.0 / (n1 + 2 * n2)) + + def cal_f(self): + # f(a_z) is a probability distribution of word sequence a_z. + # Typically f(a_z) is discounted to be less than the ML estimate so we have + # some leftover probability for the z words unseen in the context (a_). + # + # f(a_z) = (c(a_z) - D0) / c(a_) ;; for highest order N-grams + # f(_z) = (n(*_z) - D1) / n(*_*) ;; for lower order N-grams + + # highest order N-grams + n = self.ngram_order - 1 + this_order_counts = self.counts[n] + for hist, counts_for_hist in this_order_counts.items(): + for w, c in counts_for_hist.word_to_count.items(): + counts_for_hist.word_to_f[w] = max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count + + # lower order N-grams + for n in range(0, self.ngram_order - 1): + this_order_counts = self.counts[n] + for hist, counts_for_hist in this_order_counts.items(): + + n_star_star = 0 + for w in counts_for_hist.word_to_count.keys(): + n_star_star += len(counts_for_hist.word_to_context[w]) + + if n_star_star != 0: + for w in counts_for_hist.word_to_count.keys(): + n_star_z = len(counts_for_hist.word_to_context[w]) + counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star + else: # patterns begin with , they do not have "modified count", so use raw count instead + for w in counts_for_hist.word_to_count.keys(): + n_star_z = counts_for_hist.word_to_count[w] + counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / counts_for_hist.total_count + + def cal_bow(self): + # Backoff weights are only necessary for ngrams which form a prefix of a longer ngram. + # Thus, two sorts of ngrams do not have a bow: + # 1) highest order ngram + # 2) ngrams ending in + # + # bow(a_) = (1 - Sum_Z1 f(a_z)) / (1 - Sum_Z1 f(_z)) + # Note that Z1 is the set of all words with c(a_z) > 0 + + # highest order N-grams + n = self.ngram_order - 1 + this_order_counts = self.counts[n] + for hist, counts_for_hist in this_order_counts.items(): + for w in counts_for_hist.word_to_count.keys(): + counts_for_hist.word_to_bow[w] = None + + # lower order N-grams + for n in range(0, self.ngram_order - 1): + this_order_counts = self.counts[n] + for hist, counts_for_hist in this_order_counts.items(): + for w in counts_for_hist.word_to_count.keys(): + if w == self.eos_symbol: + counts_for_hist.word_to_bow[w] = None + else: + a_ = hist + (w,) + + assert len(a_) < self.ngram_order + assert a_ in self.counts[len(a_)].keys() + + a_counts_for_hist = self.counts[len(a_)][a_] + + sum_z1_f_a_z = 0 + for u in a_counts_for_hist.word_to_count.keys(): + sum_z1_f_a_z += a_counts_for_hist.word_to_f[u] + + sum_z1_f_z = 0 + _ = a_[1:] + _counts_for_hist = self.counts[len(_)][_] + for u in a_counts_for_hist.word_to_count.keys(): # Should be careful here: what is Z1 + sum_z1_f_z += _counts_for_hist.word_to_f[u] + + counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (1.0 - sum_z1_f_z) + + def print_raw_counts(self, info_string): + # these are useful for debug. + print(info_string) + res = [] + for this_order_counts in self.counts: + for hist, counts_for_hist in this_order_counts.items(): + for w in counts_for_hist.word_to_count.keys(): + ngram = " ".join(hist) + " " + w + ngram = ngram.strip(strip_chars) + + res.append("{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w])) + res.sort(reverse=True) + for r in res: + print(r) + + def print_modified_counts(self, info_string): + # these are useful for debug. + print(info_string) + res = [] + for this_order_counts in self.counts: + for hist, counts_for_hist in this_order_counts.items(): + for w in counts_for_hist.word_to_count.keys(): + ngram = " ".join(hist) + " " + w + ngram = ngram.strip(strip_chars) + + modified_count = len(counts_for_hist.word_to_context[w]) + raw_count = counts_for_hist.word_to_count[w] + + if modified_count == 0: + res.append("{0}\t{1}".format(ngram, raw_count)) + else: + res.append("{0}\t{1}".format(ngram, modified_count)) + res.sort(reverse=True) + for r in res: + print(r) + + def print_f(self, info_string): + # these are useful for debug. + print(info_string) + res = [] + for this_order_counts in self.counts: + for hist, counts_for_hist in this_order_counts.items(): + for w in counts_for_hist.word_to_count.keys(): + ngram = " ".join(hist) + " " + w + ngram = ngram.strip(strip_chars) + + f = counts_for_hist.word_to_f[w] + if f == 0: # f() is always 0 + f = 1e-99 + + res.append("{0}\t{1}".format(ngram, math.log(f, 10))) + res.sort(reverse=True) + for r in res: + print(r) + + def print_f_and_bow(self, info_string): + # these are useful for debug. + print(info_string) + res = [] + for this_order_counts in self.counts: + for hist, counts_for_hist in this_order_counts.items(): + for w in counts_for_hist.word_to_count.keys(): + ngram = " ".join(hist) + " " + w + ngram = ngram.strip(strip_chars) + + f = counts_for_hist.word_to_f[w] + if f == 0: # f() is always 0 + f = 1e-99 + + bow = counts_for_hist.word_to_bow[w] + if bow is None: + res.append("{1}\t{0}".format(ngram, math.log(f, 10))) + else: + res.append("{1}\t{0}\t{2}".format(ngram, math.log(f, 10), math.log(bow, 10))) + res.sort(reverse=True) + for r in res: + print(r) + + def print_as_arpa(self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding='latin-1')): + # print as ARPA format. + + print('\\data\\', file=fout) + for hist_len in range(self.ngram_order): + # print the number of n-grams. + print('ngram {0}={1}'.format( + hist_len + 1, + sum([len(counts_for_hist.word_to_f) for counts_for_hist in self.counts[hist_len].values()])), + file=fout + ) + + print('', file=fout) + + for hist_len in range(self.ngram_order): + print('\\{0}-grams:'.format(hist_len + 1), file=fout) + + this_order_counts = self.counts[hist_len] + for hist, counts_for_hist in this_order_counts.items(): + for word in counts_for_hist.word_to_count.keys(): + ngram = hist + (word,) + prob = counts_for_hist.word_to_f[word] + bow = counts_for_hist.word_to_bow[word] + + if prob == 0: # f() is always 0 + prob = 1e-99 + + line = '{0}\t{1}'.format('%.7f' % math.log10(prob), ' '.join(ngram)) + if bow is not None: + line += '\t{0}'.format('%.7f' % math.log10(bow)) + print(line, file=fout) + print('', file=fout) + print('\\end\\', file=fout) + + +if __name__ == "__main__": + + ngram_counts = NgramCounts(args.ngram_order) + + if args.text is None: + ngram_counts.add_raw_counts_from_standard_input() + else: + assert os.path.isfile(args.text) + ngram_counts.add_raw_counts_from_file(args.text) + + ngram_counts.cal_discounting_constants() + ngram_counts.cal_f() + ngram_counts.cal_bow() + + if args.lm is None: + ngram_counts.print_as_arpa() + else: + with open(args.lm, 'w', encoding=default_encoding) as f: + ngram_counts.print_as_arpa(fout=f) diff --git a/egs/timit/asr/simple_v1/local/make_lexicon_fst.py b/egs/timit/asr/simple_v1/local/make_lexicon_fst.py new file mode 100644 index 00000000..e22222db --- /dev/null +++ b/egs/timit/asr/simple_v1/local/make_lexicon_fst.py @@ -0,0 +1,411 @@ +#!/usr/bin/env python3 + +# Copyright 2018 Johns Hopkins University (author: Daniel Povey) +# Apache 2.0. + +# see get_args() below for usage message. +import argparse +import os +import sys +import math +import re + +# The use of latin-1 encoding does not preclude reading utf-8. latin-1 +# encoding means "treat words as sequences of bytes", and it is compatible +# with utf-8 encoding as well as other encodings such as gbk, as long as the +# spaces are also spaces in ascii (which we check). It is basically how we +# emulate the behavior of python before python3. +sys.stdout = open(1, 'w', encoding='latin-1', closefd=False) +sys.stderr = open(2, 'w', encoding='latin-1', closefd=False) + +def get_args(): + parser = argparse.ArgumentParser(description="""This script creates the + text form of a lexicon FST, to be compiled by fstcompile using the + appropriate symbol tables (phones.txt and words.txt) . It will mostly + be invoked indirectly via utils/prepare_lang.sh. The output goes to + the stdout.""") + + parser.add_argument('--sil-phone', dest='sil_phone', type=str, + help="""Text form of optional-silence phone, e.g. 'SIL'. See also + the --silprob option.""") + parser.add_argument('--sil-prob', dest='sil_prob', type=float, default=0.0, + help="""Probability of silence between words (including at the + beginning and end of word sequences). Must be in the range [0.0, 1.0]. + This refers to the optional silence inserted by the lexicon; see + the --silphone option.""") + parser.add_argument('--sil-disambig', dest='sil_disambig', type=str, + help="""Disambiguation symbol to disambiguate silence, e.g. #5. + Will only be supplied if you are creating the version of L.fst + with disambiguation symbols, intended for use with cyclic G.fst. + This symbol was introduced to fix a rather obscure source of + nondeterminism of CLG.fst, that has to do with reordering of + disambiguation symbols and phone symbols.""") + parser.add_argument('--left-context-phones', dest='left_context_phones', type=str, + help="""Only relevant if --nonterminals is also supplied; this relates + to grammar decoding (see http://kaldi-asr.org/doc/grammar.html or + src/doc/grammar.dox). Format is a list of left-context phones, + in text form, one per line. E.g. data/lang/phones/left_context_phones.txt""") + parser.add_argument('--nonterminals', type=str, + help="""If supplied, --left-context-phones must also be supplied. + List of user-defined nonterminal symbols such as #nonterm:contact_list, + one per line. E.g. data/local/dict/nonterminals.txt.""") + parser.add_argument('lexiconp', type=str, + help="""Filename of lexicon with pronunciation probabilities + (normally lexiconp.txt), with lines of the form 'word prob p1 p2...', + e.g. 'a 1.0 ay'""") + args = parser.parse_args() + return args + + +def read_lexiconp(filename): + """Reads the lexiconp.txt file in 'filename', with lines like 'word pron p1 p2 ...'. + Returns a list of tuples (word, pron_prob, pron), where 'word' is a string, + 'pron_prob', a float, is the pronunciation probability (which must be >0.0 + and would normally be <=1.0), and 'pron' is a list of strings representing phones. + An element in the returned list might be ('hello', 1.0, ['h', 'eh', 'l', 'ow']). + """ + + ans = [] + found_empty_prons = False + found_large_pronprobs = False + # See the comment near the top of this file, RE why we use latin-1. + with open(filename, 'r', encoding='latin-1') as f: + whitespace = re.compile("[ \t]+") + for line in f: + a = whitespace.split(line.strip(" \t\r\n")) + if len(a) < 2: + print("{0}: error: found bad line '{1}' in lexicon file {2} ".format( + sys.argv[0], line.strip(" \t\r\n"), filename), file=sys.stderr) + sys.exit(1) + word = a[0] + if word == "": + # This would clash with the epsilon symbol normally used in OpenFst. + print("{0}: error: found as a word in lexicon file " + "{1}".format(line.strip(" \t\r\n"), filename), file=sys.stderr) + sys.exit(1) + try: + pron_prob = float(a[1]) + except: + print("{0}: error: found bad line '{1}' in lexicon file {2}, 2nd field " + "should be pron-prob".format(sys.argv[0], line.strip(" \t\r\n"), filename), + file=sys.stderr) + sys.exit(1) + prons = a[2:] + if pron_prob <= 0.0: + print("{0}: error: invalid pron-prob in line '{1}' of lexicon file {1} ".format( + sys.argv[0], line.strip(" \t\r\n"), filename), file=sys.stderr) + sys.exit(1) + if len(prons) == 0: + found_empty_prons = True + ans.append( (word, pron_prob, prons) ) + if pron_prob > 1.0: + found_large_pronprobs = True + if found_empty_prons: + print("{0}: warning: found at least one word with an empty pronunciation " + "in lexicon file {1}.".format(sys.argv[0], filename), + file=sys.stderr) + if found_large_pronprobs: + print("{0}: warning: found at least one word with pron-prob >1.0 " + "in {1}".format(sys.argv[0], filename), file=sys.stderr) + + + if len(ans) == 0: + print("{0}: error: found no pronunciations in lexicon file {1}".format( + sys.argv[0], filename), file=sys.stderr) + sys.exit(1) + return ans + + +def write_nonterminal_arcs(start_state, loop_state, next_state, + nonterminals, left_context_phones): + """This function relates to the grammar-decoding setup, see + kaldi-asr.org/doc/grammar.html. It is called from write_fst_no_silence + and write_fst_silence, and writes to the stdout some extra arcs + in the lexicon FST that relate to nonterminal symbols. + See the section "Special symbols in L.fst, + kaldi-asr.org/doc/grammar.html#grammar_special_l. + start_state: the start-state of L.fst. + loop_state: the state of high out-degree in L.fst where words leave + and enter. + next_state: the number from which this function can start allocating its + own states. the updated value of next_state will be returned. + nonterminals: the user-defined nonterminal symbols as a list of + strings, e.g. ['#nonterm:contact_list', ... ]. + left_context_phones: a list of phones that may appear as left-context, + e.g. ['a', 'ah', ... '#nonterm_bos']. + """ + shared_state = next_state + next_state += 1 + final_state = next_state + next_state += 1 + + print("{src}\t{dest}\t{phone}\t{word}\t{cost}".format( + src=start_state, dest=shared_state, + phone='#nonterm_begin', word='#nonterm_begin', + cost=0.0)) + + for nonterminal in nonterminals: + print("{src}\t{dest}\t{phone}\t{word}\t{cost}".format( + src=loop_state, dest=shared_state, + phone=nonterminal, word=nonterminal, + cost=0.0)) + # this_cost equals log(len(left_context_phones)) but the expression below + # better captures the meaning. Applying this cost to arcs keeps the FST + # stochatic (sum-to-one, like an HMM), so that if we do weight pushing + # things won't get weird. In the grammar-FST code when we splice things + # together we will cancel out this cost, see the function CombineArcs(). + this_cost = -math.log(1.0 / len(left_context_phones)) + + for left_context_phone in left_context_phones: + print("{src}\t{dest}\t{phone}\t{word}\t{cost}".format( + src=shared_state, dest=loop_state, + phone=left_context_phone, word='', cost=this_cost)) + # arc from loop-state to a final-state with #nonterm_end as ilabel and olabel + print("{src}\t{dest}\t{phone}\t{word}\t{cost}".format( + src=loop_state, dest=final_state, + phone='#nonterm_end', word='#nonterm_end', cost=0.0)) + print("{state}\t{final_cost}".format( + state=final_state, final_cost=0.0)) + return next_state + + + +def write_fst_no_silence(lexicon, nonterminals=None, left_context_phones=None): + """Writes the text format of L.fst to the standard output. This version is for + when --sil-prob=0.0, meaning there is no optional silence allowed. + + 'lexicon' is a list of 3-tuples (word, pron-prob, prons) as returned by + read_lexiconp(). + 'nonterminals', which relates to grammar decoding (see kaldi-asr.org/doc/grammar.html), + is either None, or the user-defined nonterminal symbols as a list of + strings, e.g. ['#nonterm:contact_list', ... ]. + 'left_context_phones', which also relates to grammar decoding, and must be + supplied if 'nonterminals' is supplied is either None or a list of + phones that may appear as left-context, e.g. ['a', 'ah', ... '#nonterm_bos']. + """ + + loop_state = 0 + next_state = 1 # the next un-allocated state, will be incremented as we go. + for (word, pronprob, pron) in lexicon: + cost = -math.log(pronprob) + cur_state = loop_state + for i in range(len(pron) - 1): + print("{src}\t{dest}\t{phone}\t{word}\t{cost}".format( + src=cur_state, + dest=next_state, + phone=pron[i], + word=(word if i == 0 else ''), + cost=(cost if i == 0 else 0.0))) + cur_state = next_state + next_state += 1 + + i = len(pron) - 1 # note: i == -1 if pron is empty. + print("{src}\t{dest}\t{phone}\t{word}\t{cost}".format( + src=cur_state, + dest=loop_state, + phone=(pron[i] if i >= 0 else ''), + word=(word if i <= 0 else ''), + cost=(cost if i <= 0 else 0.0))) + + if nonterminals is not None: + next_state = write_nonterminal_arcs( + loop_state, loop_state, next_state, + nonterminals, left_context_phones) + + print("{state}\t{final_cost}".format( + state=loop_state, + final_cost=0.0)) + + +def write_fst_with_silence(lexicon, sil_prob, sil_phone, sil_disambig, + nonterminals=None, left_context_phones=None): + """Writes the text format of L.fst to the standard output. This version is for + when --sil-prob != 0.0, meaning there is optional silence + 'lexicon' is a list of 3-tuples (word, pron-prob, prons) + as returned by read_lexiconp(). + 'sil_prob', which is expected to be strictly between 0.. and 1.0, is the + probability of silence + 'sil_phone' is the silence phone, e.g. "SIL". + 'sil_disambig' is either None, or the silence disambiguation symbol, e.g. "#5". + 'nonterminals', which relates to grammar decoding (see kaldi-asr.org/doc/grammar.html), + is either None, or the user-defined nonterminal symbols as a list of + strings, e.g. ['#nonterm:contact_list', ... ]. + 'left_context_phones', which also relates to grammar decoding, and must be + supplied if 'nonterminals' is supplied is either None or a list of + phones that may appear as left-context, e.g. ['a', 'ah', ... '#nonterm_bos']. + """ + + assert sil_prob > 0.0 and sil_prob < 1.0 + sil_cost = -math.log(sil_prob) + no_sil_cost = -math.log(1.0 - sil_prob); + + start_state = 0 + loop_state = 1 # words enter and leave from here + sil_state = 2 # words terminate here when followed by silence; this state + # has a silence transition to loop_state. + next_state = 3 # the next un-allocated state, will be incremented as we go. + + + print('{src}\t{dest}\t{phone}\t{word}\t{cost}'.format( + src=start_state, dest=loop_state, + phone='', word='', cost=no_sil_cost)) + print('{src}\t{dest}\t{phone}\t{word}\t{cost}'.format( + src=start_state, dest=sil_state, + phone='', word='', cost=sil_cost)) + if sil_disambig is None: + print('{src}\t{dest}\t{phone}\t{word}\t{cost}'.format( + src=sil_state, dest=loop_state, + phone=sil_phone, word='', cost=0.0)) + else: + sil_disambig_state = next_state + next_state += 1 + print('{src}\t{dest}\t{phone}\t{word}\t{cost}'.format( + src=sil_state, dest=sil_disambig_state, + phone=sil_phone, word='', cost=0.0)) + print('{src}\t{dest}\t{phone}\t{word}\t{cost}'.format( + src=sil_disambig_state, dest=loop_state, + phone=sil_disambig, word='', cost=0.0)) + + + for (word, pronprob, pron) in lexicon: + pron_cost = -math.log(pronprob) + cur_state = loop_state + for i in range(len(pron) - 1): + print("{src}\t{dest}\t{phone}\t{word}\t{cost}".format( + src=cur_state, dest=next_state, + phone=pron[i], + word=(word if i == 0 else ''), + cost=(pron_cost if i == 0 else 0.0))) + cur_state = next_state + next_state += 1 + + i = len(pron) - 1 # note: i == -1 if pron is empty. + print("{src}\t{dest}\t{phone}\t{word}\t{cost}".format( + src=cur_state, + dest=loop_state, + phone=(pron[i] if i >= 0 else ''), + word=(word if i <= 0 else ''), + cost=no_sil_cost + (pron_cost if i <= 0 else 0.0))) + print("{src}\t{dest}\t{phone}\t{word}\t{cost}".format( + src=cur_state, + dest=sil_state, + phone=(pron[i] if i >= 0 else ''), + word=(word if i <= 0 else ''), + cost=sil_cost + (pron_cost if i <= 0 else 0.0))) + + if nonterminals is not None: + next_state = write_nonterminal_arcs( + start_state, loop_state, next_state, + nonterminals, left_context_phones) + + print("{state}\t{final_cost}".format( + state=loop_state, + final_cost=0.0)) + + + + +def write_words_txt(orig_lines, highest_numbered_symbol, nonterminals, filename): + """Writes updated words.txt to 'filename'. 'orig_lines' is the original lines + in the words.txt file as a list of strings (without the newlines); + highest_numbered_symbol is the highest numbered symbol in the original + words.txt; nonterminals is a list of strings like '#nonterm:foo'.""" + with open(filename, 'w', encoding='latin-1') as f: + for l in orig_lines: + print(l, file=f) + cur_symbol = highest_numbered_symbol + 1 + for n in [ '#nonterm_begin', '#nonterm_end' ] + nonterminals: + print("{0} {1}".format(n, cur_symbol), file=f) + cur_symbol = cur_symbol + 1 + + +def read_nonterminals(filename): + """Reads the user-defined nonterminal symbols in 'filename', checks that + it has the expected format and has no duplicates, and returns the nonterminal + symbols as a list of strings, e.g. + ['#nonterm:contact_list', '#nonterm:phone_number', ... ]. """ + ans = [line.strip(" \t\r\n") for line in open(filename, 'r', encoding='latin-1')] + if len(ans) == 0: + raise RuntimeError("The file {0} contains no nonterminals symbols.".format(filename)) + for nonterm in ans: + if nonterm[:9] != '#nonterm:': + raise RuntimeError("In file '{0}', expected nonterminal symbols to start with '#nonterm:', found '{1}'" + .format(filename, nonterm)) + if len(set(ans)) != len(ans): + raise RuntimeError("Duplicate nonterminal symbols are present in file {0}".format(filename)) + return ans + +def read_left_context_phones(filename): + """Reads, checks, and returns a list of left-context phones, in text form, one + per line. Returns a list of strings, e.g. ['a', 'ah', ..., '#nonterm_bos' ]""" + ans = [line.strip(" \t\r\n") for line in open(filename, 'r', encoding='latin-1')] + if len(ans) == 0: + raise RuntimeError("The file {0} contains no left-context phones.".format(filename)) + whitespace = re.compile("[ \t]+") + for s in ans: + if len(whitespace.split(s)) != 1: + raise RuntimeError("The file {0} contains an invalid line '{1}'".format(filename, s) ) + + if len(set(ans)) != len(ans): + raise RuntimeError("Duplicate nonterminal symbols are present in file {0}".format(filename)) + return ans + + +def is_token(s): + """Returns true if s is a string and is space-free.""" + if not isinstance(s, str): + return False + whitespace = re.compile("[ \t\r\n]+") + split_str = whitespace.split(s); + return len(split_str) == 1 and s == split_str[0] + + +def main(): + args = get_args() + + lexicon = read_lexiconp(args.lexiconp) + + if args.nonterminals is None: + nonterminals, left_context_phones = None, None + else: + if args.left_context_phones is None: + print("{0}: if --nonterminals is specified, --left-context-phones must also " + "be specified".format(sys.argv[0])) + sys.exit(1) + nonterminals = read_nonterminals(args.nonterminals) + left_context_phones = read_left_context_phones(args.left_context_phones) + + if args.sil_prob == 0.0: + write_fst_no_silence(lexicon, + nonterminals=nonterminals, + left_context_phones=left_context_phones) + else: + # Do some checking that the options make sense. + if args.sil_prob < 0.0 or args.sil_prob >= 1.0: + print("{0}: invalid value specified --sil-prob={1}".format( + sys.argv[0], args.sil_prob), file=sys.stderr) + sys.exit(1) + + if not is_token(args.sil_phone): + print("{0}: you specified --sil-prob={1} but --sil-phone is set " + "to '{2}'".format(sys.argv[0], args.sil_prob, args.sil_phone), + file=sys.stderr) + sys.exit(1) + if args.sil_disambig is not None and not is_token(args.sil_disambig): + print("{0}: invalid value --sil-disambig='{1}' was specified." + "".format(sys.argv[0], args.sil_disambig), file=sys.stderr) + sys.exit(1) + write_fst_with_silence(lexicon, args.sil_prob, args.sil_phone, + args.sil_disambig, + nonterminals=nonterminals, + left_context_phones=left_context_phones) + + + +# (lines, highest_symbol) = read_words_txt(args.input_words_txt) +# nonterminals = read_nonterminals(args.nonterminal_symbols_list) +# write_words_txt(lines, highest_symbol, nonterminals, args.output_words_txt) + + +if __name__ == '__main__': + main() diff --git a/egs/timit/asr/simple_v1/local/parse_options.sh b/egs/timit/asr/simple_v1/local/parse_options.sh new file mode 100644 index 00000000..71fb9e5e --- /dev/null +++ b/egs/timit/asr/simple_v1/local/parse_options.sh @@ -0,0 +1,97 @@ +#!/usr/bin/env bash + +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey); +# Arnab Ghoshal, Karel Vesely + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# Parse command-line options. +# To be sourced by another script (as in ". parse_options.sh"). +# Option format is: --option-name arg +# and shell variable "option_name" gets set to value "arg." +# The exception is --help, which takes no arguments, but prints the +# $help_message variable (if defined). + + +### +### The --config file options have lower priority to command line +### options, so we need to import them first... +### + +# Now import all the configs specified by command-line, in left-to-right order +for ((argpos=1; argpos<$#; argpos++)); do + if [ "${!argpos}" == "--config" ]; then + argpos_plus1=$((argpos+1)) + config=${!argpos_plus1} + [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 + . $config # source the config file. + fi +done + + +### +### Now we process the command line options +### +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + # If the enclosing script is called with --help option, print the help + # message and exit. Scripts should put help messages in $help_message + --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2; + else printf "$help_message\n" 1>&2 ; fi; + exit 0 ;; + --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" + exit 1 ;; + # If the first command-line argument begins with "--" (e.g. --foo-bar), + # then work out the variable name as $name, which will equal "foo_bar". + --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; + # Next we test whether the variable in question is undefned-- if so it's + # an invalid option and we die. Note: $0 evaluates to the name of the + # enclosing script. + # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar + # is undefined. We then have to wrap this test inside "eval" because + # foo_bar is itself inside a variable ($name). + eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + + oldval="`eval echo \\$$name`"; + # Work out whether we seem to be expecting a Boolean argument. + if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval $name=\"$2\"; + + # Check that Boolean-valued arguments are really Boolean. + if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + *) break; + esac +done + + +# Check for an empty argument to the --cmd option, which can easily occur as a +# result of scripting errors. +[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; + + +true; # so this script returns exit code 0. diff --git a/egs/timit/asr/simple_v1/local/prepare_lang.sh b/egs/timit/asr/simple_v1/local/prepare_lang.sh new file mode 100644 index 00000000..7faf63fc --- /dev/null +++ b/egs/timit/asr/simple_v1/local/prepare_lang.sh @@ -0,0 +1,382 @@ +#!/usr/bin/env bash +# Copyright 2012-2013 Johns Hopkins University (Author: Daniel Povey); +# Arnab Ghoshal +# 2014 Guoguo Chen +# 2015 Hainan Xu +# 2016 FAU Erlangen (Author: Axel Horndasch) +# 2020 Xiaomi Corporation (Author: Junbo Zhang) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +# This script prepares a directory such as data/lang/, in the standard format, +# given a source directory containing a dictionary lexicon.txt in a form like: +# word phone1 phone2 ... phoneN +# per line (alternate prons would be separate lines), or a dictionary with probabilities +# called lexiconp.txt in a form: +# word pron-prob phone1 phone2 ... phoneN +# (with 0.0 < pron-prob <= 1.0); note: if lexiconp.txt exists, we use it even if +# lexicon.txt exists. +# and also files silence_phones.txt, nonsilence_phones.txt, optional_silence.txt +# and extra_questions.txt +# Here, silence_phones.txt and nonsilence_phones.txt are lists of silence and +# non-silence phones respectively (where silence includes various kinds of +# noise, laugh, cough, filled pauses etc., and nonsilence phones includes the +# "real" phones.) +# In each line of those files is a list of phones, and the phones on each line +# are assumed to correspond to the same "base phone", i.e. they will be +# different stress or tone variations of the same basic phone. +# The file "optional_silence.txt" contains just a single phone (typically SIL) +# which is used for optional silence in the lexicon. +# extra_questions.txt might be empty; typically will consist of lists of phones, +# all members of each list with the same stress or tone; and also possibly a +# list for the silence phones. This will augment the automatically generated +# questions (note: the automatically generated ones will treat all the +# stress/tone versions of a phone the same, so will not "get to ask" about +# stress or tone). +# + +# This script adds word-position-dependent phones and constructs a host of other +# derived files, that go in data/lang/. + +# Begin configuration section. +num_sil_states=5 +num_nonsil_states=3 +position_dependent_phones=true +# position_dependent_phones is false also when position dependent phones and word_boundary.txt +# have been generated by another source +share_silence_phones=false # if true, then share pdfs of different silence + # phones together. +sil_prob=0.5 +num_extra_phone_disambig_syms=1 # Standard one phone disambiguation symbol is used for optional silence. + # Increasing this number does not harm, but is only useful if you later + # want to introduce this labels to L_disambig.fst + + +# end configuration sections + +echo "$0 $@" # Print the command line for logging + +. local/parse_options.sh + +if [ $# -ne 4 ]; then + echo "Usage: local/prepare_lang.sh " + echo "e.g.: local/prepare_lang.sh data/local/dict data/local/lang data/lang" + echo " should contain the following files:" + echo " extra_questions.txt lexicon.txt nonsilence_phones.txt optional_silence.txt silence_phones.txt" + echo "See http://kaldi-asr.org/doc/data_prep.html#data_prep_lang_creating for more info." + echo "options: " + echo " may also, for the grammar-decoding case (see http://kaldi-asr.org/doc/grammar.html)" + echo "contain a file nonterminals.txt containing symbols like #nonterm:contact_list, one per line." + echo " --num-sil-states # default: 5, #states in silence models." + echo " --num-nonsil-states # default: 3, #states in non-silence models." + echo " --position-dependent-phones (true|false) # default: true; if true, use _B, _E, _S & _I" + echo " # markers on phones to indicate word-internal positions. " + echo " --share-silence-phones (true|false) # default: false; if true, share pdfs of " + echo " # all silence phones. " + echo " --sil-prob # default: 0.5 [must have 0 <= silprob < 1]" + exit 1; +fi + +srcdir=$1 +oov_word=$2 +tmpdir=$3 +dir=$4 + + +if [ -d $dir/phones ]; then + rm -r $dir/phones +fi +mkdir -p $dir $tmpdir $dir/phones + +silprob=false +[ -f $srcdir/lexiconp_silprob.txt ] && silprob=true + +[ -f path.sh ] && . ./path.sh + +if [[ ! -f $srcdir/lexicon.txt ]]; then + echo "**Creating $srcdir/lexicon.txt from $srcdir/lexiconp.txt" + perl -ape 's/(\S+\s+)\S+\s+(.+)/$1$2/;' < $srcdir/lexiconp.txt > $srcdir/lexicon.txt || exit 1; +fi +if [[ ! -f $srcdir/lexiconp.txt ]]; then + echo "**Creating $srcdir/lexiconp.txt from $srcdir/lexicon.txt" + perl -ape 's/(\S+\s+)(.+)/${1}1.0\t$2/;' < $srcdir/lexicon.txt > $srcdir/lexiconp.txt || exit 1; +fi + +if [ ! -z "$unk_fst" ] && [ ! -f "$unk_fst" ]; then + echo "$0: expected --unk-fst $unk_fst to exist as a file" + exit 1 +fi + +if $position_dependent_phones; then + # Create $tmpdir/lexiconp.txt from $srcdir/lexiconp.txt (or + # $tmpdir/lexiconp_silprob.txt from $srcdir/lexiconp_silprob.txt) by + # adding the markers _B, _E, _S, _I depending on word position. + # In this recipe, these markers apply to silence also. + # Do this starting from lexiconp.txt only. + if "$silprob"; then + perl -ane '@A=split(" ",$_); $w = shift @A; $p = shift @A; $silword_p = shift @A; + $wordsil_f = shift @A; $wordnonsil_f = shift @A; @A>0||die; + if(@A==1) { print "$w $p $silword_p $wordsil_f $wordnonsil_f $A[0]_S\n"; } + else { print "$w $p $silword_p $wordsil_f $wordnonsil_f $A[0]_B "; + for($n=1;$n<@A-1;$n++) { print "$A[$n]_I "; } print "$A[$n]_E\n"; } ' \ + < $srcdir/lexiconp_silprob.txt > $tmpdir/lexiconp_silprob.txt + else + perl -ane '@A=split(" ",$_); $w = shift @A; $p = shift @A; @A>0||die; + if(@A==1) { print "$w $p $A[0]_S\n"; } else { print "$w $p $A[0]_B "; + for($n=1;$n<@A-1;$n++) { print "$A[$n]_I "; } print "$A[$n]_E\n"; } ' \ + < $srcdir/lexiconp.txt > $tmpdir/lexiconp.txt || exit 1; + fi + + # create $tmpdir/phone_map.txt + # this has the format (on each line) + # ... + # where the versions depend on the position of the phone within a word. + # For instance, we'd have: + # AA AA_B AA_E AA_I AA_S + # for (B)egin, (E)nd, (I)nternal and (S)ingleton + # and in the case of silence + # SIL SIL SIL_B SIL_E SIL_I SIL_S + # [because SIL on its own is one of the variants; this is for when it doesn't + # occur inside a word but as an option in the lexicon.] + + # This phone map expands the phone lists into all the word-position-dependent + # versions of the phone lists. + cat <(set -f; for x in `cat $srcdir/silence_phones.txt`; do for y in "" "" "_B" "_E" "_I" "_S"; do echo -n "$x$y "; done; echo; done) \ + <(set -f; for x in `cat $srcdir/nonsilence_phones.txt`; do for y in "" "_B" "_E" "_I" "_S"; do echo -n "$x$y "; done; echo; done) \ + > $tmpdir/phone_map.txt +else + if "$silprob"; then + cp $srcdir/lexiconp_silprob.txt $tmpdir/lexiconp_silprob.txt + else + cp $srcdir/lexiconp.txt $tmpdir/lexiconp.txt + fi + + cat $srcdir/silence_phones.txt $srcdir/nonsilence_phones.txt | \ + awk '{for(n=1;n<=NF;n++) print $n; }' > $tmpdir/phones + paste -d' ' $tmpdir/phones $tmpdir/phones > $tmpdir/phone_map.txt +fi + + +# Making monophone systems. +cat $srcdir/silence_phones.txt | local/apply_map.pl $tmpdir/phone_map.txt | \ + awk '{for(n=1;n<=NF;n++) print $n;}' > $dir/phones/silence.txt +cat $srcdir/nonsilence_phones.txt | local/apply_map.pl $tmpdir/phone_map.txt | \ + awk '{for(n=1;n<=NF;n++) print $n;}' > $dir/phones/nonsilence.txt +cp $srcdir/optional_silence.txt $dir/phones/optional_silence.txt + +# if extra_questions.txt is empty, it's OK. +cat $srcdir/extra_questions.txt 2>/dev/null | local/apply_map.pl $tmpdir/phone_map.txt \ + >$dir/phones/extra_questions.txt + +# Want extra questions about the word-start/word-end stuff. Make it separate for +# silence and non-silence. Probably doesn't matter, as silence will rarely +# be inside a word. +if $position_dependent_phones; then + for suffix in _B _E _I _S; do + (set -f; for x in `cat $srcdir/nonsilence_phones.txt`; do echo -n "$x$suffix "; done; echo) >>$dir/phones/extra_questions.txt + done + for suffix in "" _B _E _I _S; do + (set -f; for x in `cat $srcdir/silence_phones.txt`; do echo -n "$x$suffix "; done; echo) >>$dir/phones/extra_questions.txt + done +fi + +# add_lex_disambig.pl is responsible for adding disambiguation symbols to +# the lexicon, for telling us how many disambiguation symbols it used, +# and also for modifying the unknown-word's pronunciation (if the +# --unk-fst was provided) to the sequence "#1 #2 #3", and reserving those +# disambig symbols for that purpose. +# The #2 will later be replaced with the actual unk model. The reason +# for the #1 and the #3 is for disambiguation and also to keep the +# FST compact. If we didn't have the #1, we might have a different copy of +# the unk-model FST, or at least some of its arcs, for each start-state from +# which an transition comes (instead of per end-state, which is more compact); +# and adding the #3 prevents us from potentially having 2 copies of the unk-model +# FST due to the optional-silence [the last phone of any word gets 2 arcs]. +if [ ! -z "$unk_fst" ]; then # if the --unk-fst option was provided... + if "$silprob"; then + local/lang/internal/modify_unk_pron.py $tmpdir/lexiconp_silprob.txt "$oov_word" || exit 1 + else + local/lang/internal/modify_unk_pron.py $tmpdir/lexiconp.txt "$oov_word" || exit 1 + fi + unk_opt="--first-allowed-disambig 4" +else + unk_opt= +fi + +if "$silprob"; then + ndisambig=$(local/add_lex_disambig.pl $unk_opt --pron-probs --sil-probs $tmpdir/lexiconp_silprob.txt $tmpdir/lexiconp_silprob_disambig.txt) +else + ndisambig=$(local/add_lex_disambig.pl $unk_opt --pron-probs $tmpdir/lexiconp.txt $tmpdir/lexiconp_disambig.txt) +fi +ndisambig=$[$ndisambig+$num_extra_phone_disambig_syms]; # add (at least) one disambig symbol for silence in lexicon FST. +echo $ndisambig > $tmpdir/lex_ndisambig + +# Format of lexiconp_disambig.txt: +# !SIL 1.0 SIL_S +# 1.0 SPN_S #1 +# 1.0 SPN_S #2 +# 1.0 NSN_S +# !EXCLAMATION-POINT 1.0 EH2_B K_I S_I K_I L_I AH0_I M_I EY1_I SH_I AH0_I N_I P_I OY2_I N_I T_E + +( for n in `seq 0 $ndisambig`; do echo '#'$n; done ) >$dir/phones/disambig.txt + +# Create phone symbol table. +echo "" | cat - $dir/phones/{silence,nonsilence,disambig}.txt | \ + awk '{n=NR-1; print $1, n;}' > $dir/phones.txt + +# Create a file that describes the word-boundary information for +# each phone. 5 categories. +if $position_dependent_phones; then + cat $dir/phones/{silence,nonsilence}.txt | \ + awk '/_I$/{print $1, "internal"; next;} /_B$/{print $1, "begin"; next; } + /_S$/{print $1, "singleton"; next;} /_E$/{print $1, "end"; next; } + {print $1, "nonword";} ' > $dir/phones/word_boundary.txt +else + # word_boundary.txt might have been generated by another source + [ -f $srcdir/word_boundary.txt ] && cp $srcdir/word_boundary.txt $dir/phones/word_boundary.txt +fi + +# Create word symbol table. +# and are only needed due to the need to rescore lattices with +# ConstArpaLm format language model. They do not normally appear in G.fst or +# L.fst. + +if "$silprob"; then + # remove the silprob + cat $tmpdir/lexiconp_silprob.txt |\ + awk '{ + for(i=1; i<=NF; i++) { + if(i!=3 && i!=4 && i!=5) printf("%s\t", $i); if(i==NF) print ""; + } + }' > $tmpdir/lexiconp.txt +fi + +cat $tmpdir/lexiconp.txt | awk '{print $1}' | sort | uniq | awk ' + BEGIN { + print " 0"; + } + { + if ($1 == "") { + print " is in the vocabulary!" | "cat 1>&2" + exit 1; + } + if ($1 == "") { + print " is in the vocabulary!" | "cat 1>&2" + exit 1; + } + printf("%s %d\n", $1, NR); + } + END { + printf("#0 %d\n", NR+1); + printf(" %d\n", NR+2); + printf(" %d\n", NR+3); + }' > $dir/words.txt || exit 1; + +# format of $dir/words.txt: +# 0 +#a 1 +#aa 2 +#aarvark 3 +#... + +silphone=`cat $srcdir/optional_silence.txt` || exit 1; +[ -z "$silphone" ] && \ + ( echo "You have no optional-silence phone; it is required in the current scripts" + echo "but you may use the option --sil-prob 0.0 to stop it being used." ) && \ + exit 1; + +grammar_opts= + +# Create the basic L.fst without disambiguation symbols, for use +# in training. + +if $silprob; then + # Add silence probabilities (models the prob. of silence before and after each + # word). On some setups this helps a bit. See local/dict_dir_add_pronprobs.sh + # and where it's called in the example scripts (run.sh). + local/make_lexicon_fst_silprob.py $grammar_opts --sil-phone=$silphone \ + $tmpdir/lexiconp_silprob.txt $srcdir/silprob.txt | \ + local/sym2int.pl -f 3 $dir/phones.txt | \ + local/sym2int.pl -f 4 $dir/words.txt > $dir/L.fst.txt || exit 1; + + # fstcompile --isymbols=$dir/phones.txt --osymbols=$dir/words.txt \ + # --keep_isymbols=false --keep_osymbols=false | \ + # fstarcsort --sort_type=olabel > $dir/L.fst || exit 1; +else + local/make_lexicon_fst.py $grammar_opts --sil-prob=$sil_prob --sil-phone=$silphone \ + $tmpdir/lexiconp.txt | \ + local/sym2int.pl -f 3 $dir/phones.txt | \ + local/sym2int.pl -f 4 $dir/words.txt > $dir/L.fst.txt || exit 1; + + # fstcompile --isymbols=$dir/phones.txt --osymbols=$dir/words.txt \ + # --keep_isymbols=false --keep_osymbols=false | \ + # fstarcsort --sort_type=olabel > $dir/L.fst || exit 1; +fi + +# The file oov.txt contains a word that we will map any OOVs to during +# training. +echo "$oov_word" > $dir/oov.txt || exit 1; +cat $dir/oov.txt | local/sym2int.pl $dir/words.txt >$dir/oov.int || exit 1; +# integer version of oov symbol, used in some scripts. + + +# the file wdisambig.txt contains a (line-by-line) list of the text-form of the +# disambiguation symbols that are used in the grammar and passed through by the +# lexicon. At this stage it's hardcoded as '#0', but we're laying the groundwork +# for more generality (which probably would be added by another script). +# wdisambig_words.int contains the corresponding list interpreted by the +# symbol table words.txt, and wdisambig_phones.int contains the corresponding +# list interpreted by the symbol table phones.txt. +echo '#0' >$dir/phones/wdisambig.txt + +wdisambig_phone=`local/sym2int.pl $dir/phones.txt <$dir/phones/wdisambig.txt` +wdisambig_word=`local/sym2int.pl $dir/words.txt <$dir/phones/wdisambig.txt` + +# Create these lists of phones in colon-separated integer list form too, +# for purposes of being given to programs as command-line options. +for f in silence nonsilence optional_silence disambig; do + local/sym2int.pl $dir/phones.txt <$dir/phones/$f.txt >$dir/phones/$f.int + local/sym2int.pl $dir/phones.txt <$dir/phones/$f.txt | \ + awk '{printf(":%d", $1);} END{printf "\n"}' | sed s/:// > $dir/phones/$f.csl || exit 1; +done + +if [ -f $dir/phones/word_boundary.txt ]; then + local/sym2int.pl -f 1 $dir/phones.txt <$dir/phones/word_boundary.txt \ + > $dir/phones/word_boundary.int || exit 1; +fi + +silphonelist=`cat $dir/phones/silence.csl` +nonsilphonelist=`cat $dir/phones/nonsilence.csl` + +# Create the lexicon FST with disambiguation symbols, and put it in lang_test. +# There is an extra step where we create a loop to "pass through" the +# disambiguation symbols from G.fst. + +if $silprob; then + local/make_lexicon_fst_silprob.py $grammar_opts \ + --sil-phone=$silphone --sil-disambig='#'$ndisambig \ + $tmpdir/lexiconp_silprob_disambig.txt $srcdir/silprob.txt | \ + local/sym2int.pl -f 3 $dir/phones.txt | \ + local/sym2int.pl -f 4 $dir/words.txt | \ + local/fstaddselfloops.pl $wdisambig_phone $wdisambig_word > $dir/L_disambig.fst.txt || exit 1; +else + local/make_lexicon_fst.py $grammar_opts \ + --sil-prob=$sil_prob --sil-phone=$silphone --sil-disambig='#'$ndisambig \ + $tmpdir/lexiconp_disambig.txt | \ + local/sym2int.pl -f 3 $dir/phones.txt | \ + local/sym2int.pl -f 4 $dir/words.txt | \ + local/fstaddselfloops.pl $wdisambig_phone $wdisambig_word > $dir/L_disambig.fst.txt || exit 1; +fi + +exit 0; diff --git a/egs/timit/asr/simple_v1/local/sym2int.pl b/egs/timit/asr/simple_v1/local/sym2int.pl new file mode 100644 index 00000000..58167200 --- /dev/null +++ b/egs/timit/asr/simple_v1/local/sym2int.pl @@ -0,0 +1,101 @@ +#!/usr/bin/env perl +# Copyright 2010-2012 Microsoft Corporation Johns Hopkins University (Author: Daniel Povey) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +$ignore_oov = 0; + +for($x = 0; $x < 2; $x++) { + if ($ARGV[0] eq "--map-oov") { + shift @ARGV; + $map_oov = shift @ARGV; + if ($map_oov eq "-f" || $map_oov =~ m/words\.txt$/ || $map_oov eq "") { + # disallow '-f', the empty string and anything ending in words.txt as the + # OOV symbol because these are likely command-line errors. + die "the --map-oov option requires an argument"; + } + } + if ($ARGV[0] eq "-f") { + shift @ARGV; + $field_spec = shift @ARGV; + if ($field_spec =~ m/^\d+$/) { + $field_begin = $field_spec - 1; $field_end = $field_spec - 1; + } + if ($field_spec =~ m/^(\d*)[-:](\d*)/) { # accept e.g. 1:10 as a courtesy (properly, 1-10) + if ($1 ne "") { + $field_begin = $1 - 1; # Change to zero-based indexing. + } + if ($2 ne "") { + $field_end = $2 - 1; # Change to zero-based indexing. + } + } + if (!defined $field_begin && !defined $field_end) { + die "Bad argument to -f option: $field_spec"; + } + } +} + +$symtab = shift @ARGV; +if (!defined $symtab) { + print STDERR "Usage: sym2int.pl [options] symtab [input transcriptions] > output transcriptions\n" . + "options: [--map-oov ] [-f ]\n" . + "note: can look like 4-5, or 4-, or 5-, or 1.\n"; +} +open(F, "<$symtab") || die "Error opening symbol table file $symtab"; +while() { + @A = split(" ", $_); + @A == 2 || die "bad line in symbol table file: $_"; + $sym2int{$A[0]} = $A[1] + 0; +} + +if (defined $map_oov && $map_oov !~ m/^\d+$/) { # not numeric-> look it up + if (!defined $sym2int{$map_oov}) { die "OOV symbol $map_oov not defined."; } + $map_oov = $sym2int{$map_oov}; +} + +$num_warning = 0; +$max_warning = 20; + +while (<>) { + @A = split(" ", $_); + @B = (); + for ($n = 0; $n < @A; $n++) { + $a = $A[$n]; + if ( (!defined $field_begin || $n >= $field_begin) + && (!defined $field_end || $n <= $field_end)) { + $i = $sym2int{$a}; + if (!defined ($i)) { + if (defined $map_oov) { + if ($num_warning++ < $max_warning) { + print STDERR "sym2int.pl: replacing $a with $map_oov\n"; + if ($num_warning == $max_warning) { + print STDERR "sym2int.pl: not warning for OOVs any more times\n"; + } + } + $i = $map_oov; + } + } + $a = $i; + } + push @B, $a; + } + print join(" ", @B); + print "\n"; +} +if ($num_warning > 0) { + print STDERR "** Replaced $num_warning instances of OOVs with $map_oov\n"; +} + +exit(0); diff --git a/egs/timit/asr/simple_v1/local2/phones.60-48-39.map b/egs/timit/asr/simple_v1/local2/phones.60-48-39.map new file mode 100644 index 00000000..4ebcc140 --- /dev/null +++ b/egs/timit/asr/simple_v1/local2/phones.60-48-39.map @@ -0,0 +1,61 @@ +aa aa aa +ae ae ae +ah ah ah +ao ao aa +aw aw aw +ax ax ah +ax-h ax ah +axr er er +ay ay ay +b b b +bcl vcl sil +ch ch ch +d d d +dcl vcl sil +dh dh dh +dx dx dx +eh eh eh +el el l +em m m +en en n +eng ng ng +epi epi sil +er er er +ey ey ey +f f f +g g g +gcl vcl sil +h# sil sil +hh hh hh +hv hh hh +ih ih ih +ix ix ih +iy iy iy +jh jh jh +k k k +kcl cl sil +l l l +m m m +n n n +ng ng ng +nx n n +ow ow ow +oy oy oy +p p p +pau sil sil +pcl cl sil +q +r r r +s s s +sh sh sh +t t t +tcl cl sil +th th th +uh uh uh +uw uw uw +ux uw uw +v v v +w w w +y y y +z z z +zh zh sh diff --git a/egs/timit/asr/simple_v1/local2/timit_data_prep.sh b/egs/timit/asr/simple_v1/local2/timit_data_prep.sh new file mode 100644 index 00000000..c7b3e391 --- /dev/null +++ b/egs/timit/asr/simple_v1/local2/timit_data_prep.sh @@ -0,0 +1,86 @@ +#!/usr/bin/env bash + +# Copyright 2013 (Authors: Bagher BabaAli, Daniel Povey, Arnab Ghoshal) +# 2014 Brno University of Technology (Author: Karel Vesely) +# Apache 2.0. + +if [ $# -ne 1 ]; then + echo "Argument should be the Timit directory, see ../run.sh for example." + exit 1; +fi + +dir=`pwd`/data/local/data +mkdir -p $dir +local=`pwd`/local +utils=`pwd`/utils +conf=`pwd`/local2 + +. ./path.sh + +# First check if the train & test directories exist (these can either be upper- +# or lower-cased +if [ ! -d $*/TRAIN -o ! -d $*/TEST ] && [ ! -d $*/train -o ! -d $*/test ]; then + echo "timit_data_prep.sh: Spot check of command line argument failed" + echo "Command line argument must be absolute pathname to TIMIT directory" + echo "with name like /export/corpora5/LDC/LDC93S1/timit/TIMIT" + exit 1; +fi + +# Now check what case the directory structure is +uppercased=false +train_dir=train +test_dir=test +if [ -d $*/TRAIN ]; then + uppercased=true + train_dir=TRAIN + test_dir=TEST +fi + +tmpdir=$(mktemp -d /tmp/kaldi.XXXX); +trap 'rm -rf "$tmpdir"' EXIT + +# Get the list of speakers. The list of speakers in the 24-speaker core test +# set and the 50-speaker development set must be supplied to the script. All +# speakers in the 'train' directory are used for training. +if $uppercased; then + ls -d "$*"/TRAIN/DR*/* | sed -e "s:^.*/::" > $tmpdir/train_spk +else + ls -d "$*"/train/dr*/* | sed -e "s:^.*/::" > $tmpdir/train_spk +fi + +cd $dir +for x in train; do + # First, find the list of audio files (use only si & sx utterances). + # Note: train & test sets are under different directories, but doing find on + # both and grepping for the speakers will work correctly. + find $*/{$train_dir,$test_dir} -not \( -iname 'SA*' \) -iname '*.WAV' \ + | grep -f $tmpdir/${x}_spk > ${x}_sph.flist + + sed -e 's:.*/\(.*\)/\(.*\).\(WAV\|wav\)$:\1_\2:' ${x}_sph.flist \ + > $tmpdir/${x}_sph.uttids + paste $tmpdir/${x}_sph.uttids ${x}_sph.flist \ + | sort -k1,1 > ${x}_sph.scp + + cat ${x}_sph.scp | awk '{print $1}' > ${x}.uttids + + # Now, Convert the transcripts into our format (no normalization yet) + # Get the transcripts: each line of the output contains an utterance + # ID followed by the transcript. + find $*/{$train_dir,$test_dir} -not \( -iname 'SA*' \) -iname '*.PHN' \ + | grep -f $tmpdir/${x}_spk > $tmpdir/${x}_phn.flist + sed -e 's:.*/\(.*\)/\(.*\).\(PHN\|phn\)$:\1_\2:' $tmpdir/${x}_phn.flist \ + > $tmpdir/${x}_phn.uttids + while read line; do + [ -f $line ] || error_exit "Cannot find transcription file '$line'"; + cut -f3 -d' ' "$line" | tr '\n' ' ' | perl -ape 's: *$:\n:;' + done < $tmpdir/${x}_phn.flist > $tmpdir/${x}_phn.trans + paste $tmpdir/${x}_phn.uttids $tmpdir/${x}_phn.trans \ + | sort -k1,1 > ${x}.trans + + # Do normalization steps. + cat ${x}.trans | $conf/timit_norm_trans.pl -i - -m $conf/phones.60-48-39.map -to 48 | sort > $x.text || exit 1; + +done + +echo "Data preparation succeeded" + diff --git a/egs/timit/asr/simple_v1/local2/timit_norm_trans.pl b/egs/timit/asr/simple_v1/local2/timit_norm_trans.pl new file mode 100644 index 00000000..7394323c --- /dev/null +++ b/egs/timit/asr/simple_v1/local2/timit_norm_trans.pl @@ -0,0 +1,92 @@ +#!/usr/bin/env perl +use warnings; #sed replacement for -w perl parameter + +# Copyright 2012 Arnab Ghoshal + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# This script normalizes the TIMIT phonetic transcripts that have been +# extracted in a format where each line contains an utterance ID followed by +# the transcript, e.g.: +# fcke0_si1111 h# hh ah dx ux w iy dcl d ix f ay n ih q h# + +my $usage = "Usage: timit_norm_trans.pl -i transcript -m phone_map -from [60|48] -to [48|39] > normalized\n +Normalizes phonetic transcriptions for TIMIT, by mapping the phones to a +smaller set defined by the -m option. This script assumes that the mapping is +done in the \"standard\" fashion, i.e. to 48 or 39 phones. The input is +assumed to have 60 phones (+1 for glottal stop, which is deleted), but that can +be changed using the -from option. The input format is assumed to be utterance +ID followed by transcript on the same line.\n"; + +use strict; +use Getopt::Long; +die "$usage" unless(@ARGV >= 1); +my ($in_trans, $phone_map, $num_phones_out); +my $num_phones_in = 60; +GetOptions ("i=s" => \$in_trans, # Input transcription + "m=s" => \$phone_map, # File containing phone mappings + "from=i" => \$num_phones_in, # Input #phones: must be 60 or 48 + "to=i" => \$num_phones_out ); # Output #phones: must be 48 or 39 + +die $usage unless(defined($in_trans) && defined($phone_map) && + defined($num_phones_out)); +if ($num_phones_in != 60 && $num_phones_in != 48) { + die "Can only used 60 or 48 for -from (used $num_phones_in)." +} +if ($num_phones_out != 48 && $num_phones_out != 39) { + die "Can only used 48 or 39 for -to (used $num_phones_out)." +} +unless ($num_phones_out < $num_phones_in) { + die "Argument to -from ($num_phones_in) must be greater than that to -to ($num_phones_out)." +} + + +open(M, "<$phone_map") or die "Cannot open mappings file '$phone_map': $!"; +my (%phonemap, %seen_phones); +my $num_seen_phones = 0; +while () { + chomp; + next if ($_ =~ /^q\s*.*$/); # Ignore glottal stops. + m:^(\S+)\s+(\S+)\s+(\S+)$: or die "Bad line: $_"; + my $mapped_from = ($num_phones_in == 60)? $1 : $2; + my $mapped_to = ($num_phones_out == 48)? $2 : $3; + if (!defined($seen_phones{$mapped_to})) { + $seen_phones{$mapped_to} = 1; + $num_seen_phones += 1; + } + $phonemap{$mapped_from} = $mapped_to; +} +if ($num_seen_phones != $num_phones_out) { + die "Trying to map to $num_phones_out phones, but seen only $num_seen_phones"; +} + +open(T, "<$in_trans") or die "Cannot open transcription file '$in_trans': $!"; +while () { + chomp; + $_ =~ m:^(\S+)\s+(.+): or die "Bad line: $_"; + my $utt_id = $1; + my $trans = $2; + + $trans =~ s/q//g; # Remove glottal stops. + $trans =~ s/^\s*//; $trans =~ s/\s*$//; # Normalize spaces + + print $utt_id; + for my $phone (split(/\s+/, $trans)) { + if(exists $phonemap{$phone}) { print " $phonemap{$phone}"; } + if(not exists $phonemap{$phone}) { print " $phone"; } + } + print "\n"; +} + diff --git a/egs/timit/asr/simple_v1/local2/timit_prepare_dict.sh b/egs/timit/asr/simple_v1/local2/timit_prepare_dict.sh new file mode 100644 index 00000000..5540ee13 --- /dev/null +++ b/egs/timit/asr/simple_v1/local2/timit_prepare_dict.sh @@ -0,0 +1,62 @@ +#!/usr/bin/env bash + +# Copyright 2013 (Authors: Daniel Povey, Bagher BabaAli) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +# Call this script from one level above, e.g. from the s3/ directory. It puts +# its output in data/local/. + +# The parts of the output of this that will be needed are +# [in data/local/dict/ ] +# lexicon.txt +# extra_questions.txt +# nonsilence_phones.txt +# optional_silence.txt +# silence_phones.txt + +# run this from ../ +srcdir=data/local/data +dir=data/local/dict + +mkdir -p $dir + +[ -f path.sh ] && . ./path.sh + +#(1) Dictionary preparation: + +# Make phones symbol-table (adding in silence and verbal and non-verbal noises at this point). +# We are adding suffixes _B, _E, _S for beginning, ending, and singleton phones. + +# silence phones, one per line. +echo sil > $dir/silence_phones.txt +echo sil > $dir/optional_silence.txt + +# nonsilence phones; on each line is a list of phones that correspond +# really to the same base phone. + +# Create the lexicon, which is just an identity mapping +cut -d' ' -f2- $srcdir/train.text | tr ' ' '\n' | sort -u > $dir/phones.txt +echo "" >> $dir/phones.txt +paste $dir/phones.txt $dir/phones.txt > $dir/lexicon.txt || exit 1; +grep -v -F -f $dir/silence_phones.txt $dir/phones.txt > $dir/nonsilence_phones.txt + +# A few extra questions that will be added to those obtained by automatically clustering +# the "real" phones. These ask about stress; there's also one for silence. +cat $dir/silence_phones.txt| awk '{printf("%s ", $1);} END{printf "\n";}' > $dir/extra_questions.txt || exit 1; +cat $dir/nonsilence_phones.txt | perl -e 'while(<>){ foreach $p (split(" ", $_)) { + $p =~ m:^([^\d]+)(\d*)$: || die "Bad phone $_"; $q{$2} .= "$p "; } } foreach $l (values %q) {print "$l\n";}' \ + >> $dir/extra_questions.txt || exit 1; + +echo "Dict prepatation succeeded" diff --git a/egs/timit/asr/simple_v1/local2/train_lms.sh b/egs/timit/asr/simple_v1/local2/train_lms.sh new file mode 100644 index 00000000..0d0be560 --- /dev/null +++ b/egs/timit/asr/simple_v1/local2/train_lms.sh @@ -0,0 +1,92 @@ +#!/usr/bin/env bash + +# To be run from one directory above this script. +. ./path.sh + +text=data/local/data/train.text +lexicon=data/local/dict/lexicon.txt + +for f in "$text" "$lexicon"; do + [ ! -f $x ] && echo "$0: No such file $f" && exit 1 +done + +# This script takes no arguments. It assumes you have already run +# aishell_data_prep.sh. +# It takes as input the files +# data/local/train/text +# data/local/dict/lexicon.txt +dir=data/local/lm +mkdir -p $dir + +kaldi_lm=$(which train_lm.sh) +if [ -z $kaldi_lm ]; then + echo "$0: train_lm.sh is not found. That might mean it's not installed" + echo "$0: or it is not added to PATH" + echo "$0: Please use the following commands to install it" + echo " git clone https://github.com/danpovey/kaldi_lm.git" + echo " cd kaldi_lm" + echo " make -j" + echo "Then add the path of kaldi_lm to PATH and rerun $0" + exit 1 +fi + +cleantext=$dir/text.no_oov + +cat $text | awk -v lex=$lexicon 'BEGIN{while((getline0){ seen[$1]=1; } } + {for(n=1; n<=NF;n++) { if (seen[$n]) { printf("%s ", $n); } else {printf(" ");} } printf("\n");}' \ + >$cleantext || exit 1 + +cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | sort | uniq -c | + sort -nr >$dir/word.counts || exit 1 + +# Get counts from acoustic training transcripts, and add one-count +# for each word in the lexicon (but not silence, we don't want it +# in the LM-- we'll add it optionally later). +cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | + cat - <(grep -w -v '!SIL' $lexicon | awk '{print $1}') | + sort | uniq -c | sort -nr >$dir/unigram.counts || exit 1 + +# note: we probably won't really make use of as there aren't any OOVs +cat $dir/unigram.counts | awk '{print $2}' | get_word_map.pl "" "" "" >$dir/word_map || + exit 1 + +# note: ignore 1st field of train.txt, it's the utterance-id. +cat $cleantext | awk -v wmap=$dir/word_map 'BEGIN{while((getline0)map[$1]=$2;} + { for(n=2;n<=NF;n++) { printf map[$n]; if(n$dir/train.gz || + exit 1 + +train_lm.sh --arpa --lmtype 3gram-mincount $dir || exit 1 + +# LM is small enough that we don't need to prune it (only about 0.7M N-grams). +# Perplexity over 128254.000000 words is 90.446690 + +# note: output is +# data/local/lm/3gram-mincount/lm_unpruned.gz + +exit 0 + +# From here is some commands to do a baseline with SRILM (assuming +# you have it installed). +heldout_sent=10000 # Don't change this if you want result to be comparable with +# kaldi_lm results +sdir=$dir/srilm # in case we want to use SRILM to double-check perplexities. +mkdir -p $sdir +cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n$sdir/heldout +cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n$sdir/train + +cat $dir/word_map | awk '{print $1}' | cat - <( + echo "" + echo "" +) >$sdir/wordlist + +ngram-count -text $sdir/train -order 3 -limit-vocab -vocab $sdir/wordlist -unk \ + -map-unk "" -kndiscount -interpolate -lm $sdir/srilm.o3g.kn.gz +ngram -lm $sdir/srilm.o3g.kn.gz -ppl $sdir/heldout +# 0 zeroprobs, logprob= -250954 ppl= 90.5091 ppl1= 132.482 + +# Note: perplexity SRILM gives to Kaldi-LM model is same as kaldi-lm reports above. +# Difference in WSJ must have been due to different treatment of . +ngram -lm $dir/3gram-mincount/lm_unpruned.gz -ppl $sdir/heldout +# 0 zeroprobs, logprob= -250913 ppl= 90.4439 ppl1= 132.379 diff --git a/egs/timit/asr/simple_v1/path.sh b/egs/timit/asr/simple_v1/path.sh new file mode 100644 index 00000000..4e9d61ef --- /dev/null +++ b/egs/timit/asr/simple_v1/path.sh @@ -0,0 +1,3 @@ +export SNOWFALL_ROOT=`pwd`/../../../.. +[ -f $SNOWFALL_ROOT/tools/env.sh ] && . $SNOWFALL_ROOT/tools/env.sh +export LC_ALL=C \ No newline at end of file diff --git a/egs/timit/asr/simple_v1/prepare.py b/egs/timit/asr/simple_v1/prepare.py new file mode 100644 index 00000000..310d4b08 --- /dev/null +++ b/egs/timit/asr/simple_v1/prepare.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2020 Xiaomi Corporation (authors: Junbo Zhang, Haowen Qiu) +# Copyright (c) 2021 Xiaomi Corporation (authors: Mingshuang Luo) +# Apache 2.0 + +import argparse +import os +import subprocess +import sys +from contextlib import contextmanager +from pathlib import Path + +import torch +from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer, combine +from lhotse.recipes import download_timit, prepare_timit, prepare_musan + +from snowfall.common import str2bool + +# Torch's multithreaded behavior needs to be disabled or it wastes a lot of CPU and +# slow things down. Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) +num_jobs = min(16, os.cpu_count()) + +@contextmanager +def get_executor(): + # We'll either return a process pool or a distributed worker pool. + # Note that this has to be a context manager because we might use multiple + # context manager ("with" clauses) inside, and this way everything will + # free up the resources at the right time. + try: + # If this is executed on the CLSP grid, we will try to use the + # Grid Engine to distribute the tasks. + # Other clusters can also benefit from that, provided a cluster-specific wrapper. + # (see https://github.com/pzelasko/plz for reference) + # + # The following must be installed: + # $ pip install dask distributed + # $ pip install git+https://github.com/pzelasko/plz + name = subprocess.check_output('hostname -f', shell=True, text=True) + if name.strip().endswith('.clsp.jhu.edu'): + import plz + from distributed import Client + with plz.setup_cluster() as cluster: + cluster.scale(80) + yield Client(cluster) + return + except: + pass + # No need to return anything - compute_and_store_features + # will just instantiate the pool itself. + yield None + + +def locate_corpus(*corpus_dirs): + for d in corpus_dirs: + if os.path.exists(d): + return d + print("Please create a place on your system to put the downloaded timit data " + "and add it to `corpus_dirs`") + sys.exit(1) + +def get_parser(): + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + '--num-jobs', + type=int, + default=min(15, os.cpu_count()), + help='the value of jobs is bigger, it processes the data more quickly.') + + parser.add_argument( + '--num-phones', + type=int, + default=48, + help='the number of phones for modeling, it must be in [60, 48, 39].') + + return parser + +def main(): + args = get_parser().parse_args() + corpus_dir = locate_corpus( + Path('/ceph-meixu/luomingshuang/audio-data/timit'), + Path('/export/common/data/corpora/timit'), + Path('/ceph-fj/data/timit'), + ) + + musan_dir = locate_corpus( + Path('/ceph-meixu/luomingshuang/audio-data/musan'), + Path('/export/common/data/corpora/MUSAN/musan'), + Path('/ceph-fj/data/musan'), + ) + + output_dir = Path('exp/data') + splits_dir = Path('splits_dir') + + print('Timit manifest preparation:') + timit_manifests = prepare_timit( + corpus_dir = corpus_dir, + splits_dir = splits_dir, + output_dir = output_dir, + num_phones = args.num_phones, + num_jobs = args.num_jobs) + + print('Musan manifest preparation:') + musan_cuts_path = output_dir / 'cuts_musan.json.gz' + musan_manifests = prepare_musan( + corpus_dir=musan_dir, + output_dir=output_dir, + parts=('music', 'speech', 'noise') + ) + + print('Feature extraction:') + extractor = Fbank(FbankConfig(num_mel_bins=80)) + with get_executor() as ex: # Initialize the executor only once. + for partition, manifests in timit_manifests.items(): + if (output_dir / f'cuts_{partition}.json.gz').is_file(): + print(f'{partition} already exists - skipping.') + continue + print('Processing', partition) + cut_set = CutSet.from_manifests( + recordings=manifests['recordings'], + supervisions=manifests['supervisions'] + ) + if 'train' in partition: + cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f'{output_dir}/feats_{partition}', + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomHdf5Writer + ) + timit_manifests[partition]['cuts'] = cut_set + cut_set.to_json(output_dir / f'cuts_{partition}.json.gz') + # Now onto Musan + if not musan_cuts_path.is_file(): + print('Extracting features for Musan') + # create chunks of Musan with duration 5 - 10 seconds + musan_cuts = CutSet.from_manifests( + recordings=combine(part['recordings'] for part in musan_manifests.values()) + ).cut_into_windows(10.0).filter(lambda c: c.duration > 5).compute_and_store_features( + extractor=extractor, + storage_path=f'{output_dir}/feats_musan', + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomHdf5Writer + ) + musan_cuts.to_json(musan_cuts_path) + + +if __name__ == '__main__': + main() diff --git a/egs/timit/asr/simple_v1/run.sh b/egs/timit/asr/simple_v1/run.sh new file mode 100644 index 00000000..728eb4b0 --- /dev/null +++ b/egs/timit/asr/simple_v1/run.sh @@ -0,0 +1,105 @@ +#!/usr/bin/env bash + +# Copyright 2020 Xiaomi Corporation (Author: Junbo Zhang +# Mingshuang Luo) +# 2021 Pingfeng Luo +# Apache 2.0 + +# Example of how to build L and G FST for K2. Most scripts of this example are copied from Kaldi. + +set -eou pipefail + +dataset_path=( + /ceph-meixu/luomingshuang/audio-data/timit + ) + +data=${dataset_path[0]} +for d in ${dataset_path[@]}; do + if [ -d $d ]; then + data=$d + break + fi +done + +if [ ! -d $data ]; then + echo "$data does not exist" + exit 1 +fi + +[ -f path.sh ] && . ./path.sh + +stage=1 + +if [ $stage -le 1 ]; then + echo "Data preparation" + local2/timit_data_prep.sh $data +fi + + +if [ $stage -le 2 ]; then + echo "Dict preparation" + local2/timit_prepare_dict.sh +fi + + +if [ $stage -le 3 ]; then + echo "Lang preparation" + local/prepare_lang.sh \ + --sil-prob 0.0 \ + --position-dependent-phones false \ + --num-sil-states 3 \ + data/local/dict \ + "sil" \ + data/local/lang_tmp_nosp \ + data/lang_nosp + + echo "To load L:" + echo "Use::" + echo " with open('data/lang_nosp/L.fst.txt') as f:" + echo " Lfst = k2.Fsa.from_openfst(f.read(), acceptor=False)" + echo "" +fi + + +if [ $stage -le 4 ]; then + echo "LM preparation" + local2/train_lms.sh + gunzip -c data/local/lm/3gram-mincount/lm_unpruned.gz >data/local/lm/lm_tgmed.arpa + # Note: you need to install kaldilm using `pip install kaldilm` + # Build G + python3 -m kaldilm \ + --read-symbol-table="data/lang_nosp/words.txt" \ + --disambig-symbol='#0' \ + --max-order=1 \ + data/local/lm/lm_tgmed.arpa >data/lang_nosp/G_uni.fst.txt + + python3 -m kaldilm \ + --read-symbol-table="data/lang_nosp/words.txt" \ + --disambig-symbol='#0' \ + --max-order=3 \ + data/local/lm/lm_tgmed.arpa >data/lang_nosp/G.fst.txt + + echo "" + echo "To load G:" + echo "Use::" + echo " with open('data/lang_nosp/G.fst.txt') as f:" + echo " G = k2.Fsa.from_openfst(f.read(), acceptor=False)" + echo "" +fi + + +if [ $stage -le 5 ]; then + echo "Feature preparation" + python3 ./prepare.py +fi + +if [ $stage -le 6 ]; then + echo "Training" + python3 ./ctc_train.py +fi + +if [ $stage -le 7 ]; then + echo "Decoding" + python3 ./ctc_decode.py +fi + From ca0a7fb2c16fc50162e0430c897c19ec2952af24 Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Fri, 3 Sep 2021 11:27:38 +0800 Subject: [PATCH 03/10] Delete mmi_bigram_train.py --- egs/aishell/asr/simple_v1/mmi_bigram_train.py | 513 ------------------ 1 file changed, 513 deletions(-) delete mode 100644 egs/aishell/asr/simple_v1/mmi_bigram_train.py diff --git a/egs/aishell/asr/simple_v1/mmi_bigram_train.py b/egs/aishell/asr/simple_v1/mmi_bigram_train.py deleted file mode 100644 index c24b98a5..00000000 --- a/egs/aishell/asr/simple_v1/mmi_bigram_train.py +++ /dev/null @@ -1,513 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) 2020 Xiaomi Corporation (authors: Daniel Povey, Haowen Qiu, Mingshuang Luo) -# 2021 Pingfeng Luo -# Apache 2.0 - -import k2 -import logging -import math -import numpy as np -import os -import sys -import torch -import torch.optim as optim -from datetime import datetime -from pathlib import Path -from torch import nn -from torch.nn.utils import clip_grad_value_, clip_grad_norm_ -from torch.utils.tensorboard import SummaryWriter -from typing import Dict, Optional, Tuple, List - -from lhotse import CutSet -from lhotse.dataset import CutConcatenate, CutMix, K2SpeechRecognitionDataset, SingleCutSampler -from lhotse.utils import fix_random_seed, nullcontext -from snowfall.common import describe -from snowfall.common import find_first_disambig_symbol -from snowfall.common import load_checkpoint, save_checkpoint, str2bool -from snowfall.common import save_training_info -from snowfall.common import setup_logger -from snowfall.dist import cleanup_dist, setup_dist -from snowfall.lexicon import Lexicon -from snowfall.models import AcousticModel -from snowfall.models.tdnn_lstm import TdnnLstm1b -from snowfall.objectives.mmi import LFMMILoss -from snowfall.training.diagnostics import measure_gradient_norms, optim_step_and_measure_param_change -from snowfall.training.mmi_graph import MmiTrainingGraphCompiler -from snowfall.training.mmi_graph import create_bigram_phone_lm -from snowfall.training.mmi_graph import get_phone_symbols - -den_scale = 1.0 - -def encode_supervisions(supervisions: Dict[str, torch.Tensor], - subsampling_factor) -> Tuple[torch.Tensor, List[str]]: - """ - Encodes Lhotse's ``batch["supervisions"]`` dict into a pair of torch Tensor, - and a list of transcription strings. - The supervision tensor has shape ``(batch_size, 3)``. - Its second dimension contains information about sequence index [0], - start frames [1] and num frames [2]. - The batch items might become re-ordered during this operation -- the returned tensor - and list of strings are guaranteed to be consistent with each other. - This mimics subsampling by a factor of 4 with Conv1D layer with no padding. - """ - supervision_segments = torch.stack( - (supervisions['sequence_idx'], - torch.floor_divide(supervisions['start_frame'], - subsampling_factor), - torch.floor_divide(supervisions['num_frames'], - subsampling_factor)), 1).to(torch.int32) - supervision_segments = torch.clamp(supervision_segments, min=0) - indices = torch.argsort(supervision_segments[:, 2], descending=True) - supervision_segments = supervision_segments[indices] - texts = supervisions['text'] - texts = [texts[idx] for idx in indices] - return supervision_segments, texts - -def get_objf(batch: Dict, - model: AcousticModel, - device: torch.device, - graph_compiler: MmiTrainingGraphCompiler, - is_training: bool, - tb_writer: Optional[SummaryWriter] = None, - global_batch_idx_train: Optional[int] = None, - optimizer: Optional[torch.optim.Optimizer] = None): - feature = batch['inputs'] - supervisions = batch['supervisions'] - supervision_segments = torch.stack( - (supervisions['sequence_idx'], - torch.floor_divide(supervisions['start_frame'], - model.subsampling_factor), - torch.floor_divide(supervisions['num_frames'], - model.subsampling_factor)), 1).to(torch.int32) - indices = torch.argsort(supervision_segments[:, 2], descending=True) - supervision_segments = supervision_segments[indices] - - texts = supervisions['text'] - texts = [texts[idx] for idx in indices] - assert feature.ndim == 3 - # print(supervision_segments[:, 1] + supervision_segments[:, 2]) - - feature = feature.to(device) - # at entry, feature is [N, T, C] - feature = feature.permute(0, 2, 1) # now feature is [N, C, T] - assert feature.ndim == 3 - feature = feature.to(device) - - try: - subsampling_factor = model.subsampling_factor - except: - subsampling_factor = model.module.subsampling_factor - - supervisions = batch['supervisions'] - supervision_segments, texts = encode_supervisions(supervisions, - subsampling_factor) - - loss_fn = LFMMILoss( - graph_compiler=graph_compiler, - den_scale=den_scale - ) - - grad_context = nullcontext if is_training else torch.no_grad - - with grad_context(): - nnet_output = model(feature) - # nnet_output is [N, C, T] - nnet_output = nnet_output.permute(0, 2, 1) # now nnet_output is [N, T, C] - mmi_loss, tot_frames, all_frames = loss_fn(nnet_output, texts, supervision_segments) - - if is_training: - def maybe_log_gradients(tag: str): - if ( - tb_writer is not None - and global_batch_idx_train is not None - and global_batch_idx_train % 200 == 0 - ): - tb_writer.add_scalars( - tag, - measure_gradient_norms(model, norm='l1'), - global_step=global_batch_idx_train - ) - - optimizer.zero_grad() - (-mmi_loss).backward() - - maybe_log_gradients('train/grad_norms') - #clip_grad_value_(model.parameters(), 5.0) - clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2.0) - maybe_log_gradients('train/clipped_grad_norms') - if tb_writer is not None and global_batch_idx_train % 200 == 0: - # Once in a time we will perform a more costly diagnostic - # to check the relative parameter change per minibatch. - deltas = optim_step_and_measure_param_change(model, optimizer) - tb_writer.add_scalars( - 'train/relative_param_change_per_minibatch', - deltas, - global_step=global_batch_idx_train - ) - else: - optimizer.step() - - ans = -mmi_loss.detach().cpu().item(), tot_frames.cpu().item(), all_frames.cpu().item() - return ans - - -def get_validation_objf(dataloader: torch.utils.data.DataLoader, - model: AcousticModel, - device: torch.device, - graph_compiler: MmiTrainingGraphCompiler): - total_objf = 0. - total_frames = 0. # for display only - total_all_frames = 0. # all frames including those seqs that failed. - - model.eval() - - for batch_idx, batch in enumerate(dataloader): - objf, frames, all_frames = get_objf(batch, model, device, - graph_compiler, False) - total_objf += objf - total_frames += frames - total_all_frames += all_frames - - return total_objf, total_frames, total_all_frames - - -def train_one_epoch(dataloader: torch.utils.data.DataLoader, - valid_dataloader: torch.utils.data.DataLoader, - model: AcousticModel, - device: torch.device, - graph_compiler: MmiTrainingGraphCompiler, - optimizer: torch.optim.Optimizer, - current_epoch: int, - tb_writer: SummaryWriter, - num_epochs: int, - global_batch_idx_train: int): - total_objf, total_frames, total_all_frames = 0., 0., 0. - valid_average_objf = float('inf') - time_waiting_for_batch = 0 - prev_timestamp = datetime.now() - - model.train() - for batch_idx, batch in enumerate(dataloader): - global_batch_idx_train += 1 - timestamp = datetime.now() - time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds() - - curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf( - batch=batch, - model=model, - device=device, - graph_compiler=graph_compiler, - is_training=True, - tb_writer=tb_writer, - global_batch_idx_train=global_batch_idx_train, - optimizer=optimizer - ) - - total_objf += curr_batch_objf - total_frames += curr_batch_frames - total_all_frames += curr_batch_all_frames - - if batch_idx % 10 == 0: - logging.info( - 'batch {}, epoch {}/{} ' - 'global average objf: {:.6f} over {} ' - 'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) ' - 'avg time waiting for batch {:.3f}s'.format( - batch_idx, current_epoch, num_epochs, - total_objf / total_frames, total_frames, - 100.0 * total_frames / total_all_frames, - curr_batch_objf / (curr_batch_frames + 0.001), - curr_batch_frames, - 100.0 * curr_batch_frames / curr_batch_all_frames, - time_waiting_for_batch / max(1, batch_idx))) - - print( - 'batch {}, epoch {}/{} ' - 'global average objf: {:.6f} over {} ' - 'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) ' - 'avg time waiting for batch {:.3f}s'.format( - batch_idx, current_epoch, num_epochs, - total_objf / total_frames, total_frames, - 100.0 * total_frames / total_all_frames, - curr_batch_objf / (curr_batch_frames + 0.001), - curr_batch_frames, - 100.0 * curr_batch_frames / curr_batch_all_frames, - time_waiting_for_batch / max(1, batch_idx))) - tb_writer.add_scalar('train/global_average_objf', - total_objf / total_frames, global_batch_idx_train) - - tb_writer.add_scalar('train/current_batch_average_objf', - curr_batch_objf / (curr_batch_frames + 0.001), - global_batch_idx_train) - # if batch_idx >= 10: - # print("Exiting early to get profile info") - # sys.exit(0) - - if batch_idx > 0 and batch_idx % 200 == 0: - total_valid_objf, total_valid_frames, total_valid_all_frames = get_validation_objf( - dataloader=valid_dataloader, - model=model, - device=device, - graph_compiler=graph_compiler) - valid_average_objf = total_valid_objf / total_valid_frames - model.train() - logging.info( - 'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)' - .format(valid_average_objf, - total_valid_frames, - 100.0 * total_valid_frames / total_valid_all_frames)) - - tb_writer.add_scalar('train/global_valid_average_objf', - valid_average_objf, - global_batch_idx_train) - model.write_tensorboard_diagnostics(tb_writer, global_step=global_batch_idx_train) - prev_timestamp = datetime.now() - return total_objf / total_frames, valid_average_objf, global_batch_idx_train - -def get_parser(): - import argparse - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--world_size', default=1, type=int) - parser.add_argument('--local_rank', default=0, type=int) - parser.add_argument('--master_port', default=str(12345), type=str) - parser.add_argument('--bucketing_sampler', type=str2bool, default=True) - - return parser - -def main(): - args = get_parser().parse_args() - print('World size:', args.world_size, 'Rank:', args.local_rank) - setup_dist(rank=args.local_rank, world_size=args.world_size, master_port=args.master_port) - fix_random_seed(42) - - start_epoch = 0 - num_epochs = 10 - use_adam = True - - exp_dir = f'exp-lstm-adam-mmi-bigram-musan' - setup_logger('{}/log/log-train'.format(exp_dir)) - tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard') - - # load L, G, symbol_table - lang_dir = Path('data/lang_nosp') - lexicon = Lexicon(lang_dir) - - device_id = args.local_rank - device = torch.device('cuda', device_id) - phone_ids = lexicon.phone_symbols() - - if not Path(lang_dir / 'P.pt').is_file(): - logging.debug(f'Loading P from {lang_dir}/P.fst.txt') - with open(lang_dir / 'P.fst.txt') as f: - # P is not an acceptor because there is - # a back-off state, whose incoming arcs - # have label #0 and aux_label eps. - P = k2.Fsa.from_openfst(f.read(), acceptor=False) - - phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt') - first_phone_disambig_id = find_first_disambig_symbol(phone_symbol_table) - - # P.aux_labels is not needed in later computations, so - # remove it here. - del P.aux_labels - # CAUTION(fangjun): The following line is crucial. - # Arcs entering the back-off state have label equal to #0. - # We have to change it to 0 here. - P.labels[P.labels >= first_phone_disambig_id] = 0 - - P = k2.remove_epsilon(P) - P = k2.arc_sort(P) - torch.save(P.as_dict(), lang_dir / 'P.pt') - else: - logging.debug('Loading pre-compiled P') - d = torch.load(lang_dir / 'P.pt') - P = k2.Fsa.from_dict(d) - - graph_compiler = MmiTrainingGraphCompiler( - lexicon=lexicon, - P=P, - device=device, - ) - - # load dataset - feature_dir = Path('exp/data') - logging.info("About to get train cuts") - cuts_train = CutSet.from_json(feature_dir / - 'cuts_train.json.gz') - logging.info("About to get dev cuts") - cuts_dev = CutSet.from_json(feature_dir / 'cuts_dev.json.gz') - logging.info("About to get Musan cuts") - cuts_musan = CutSet.from_json(feature_dir / 'cuts_musan.json.gz') - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - cuts_train, - cut_transforms=[ - CutConcatenate(), - CutMix( - cuts=cuts_musan, - prob=0.5, - snr=(10, 20) - ) - ] - ) - train_sampler = SingleCutSampler( - cuts_train, - max_frames=40000, - shuffle=True, - ) - logging.info("About to create train dataloader") - train_dl = torch.utils.data.DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=4 - ) - logging.info("About to create dev dataset") - validate = K2SpeechRecognitionDataset(cuts_dev) - valid_sampler = SingleCutSampler(cuts_dev, max_frames=12000) - logging.info("About to create dev dataloader") - valid_dl = torch.utils.data.DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=1 - ) - - if not torch.cuda.is_available(): - logging.error('No GPU detected!') - sys.exit(-1) - - logging.info("About to create model") - device_id = 0 - device = torch.device('cuda', device_id) - model = TdnnLstm1b(num_features=40, - num_classes=len(phone_ids) + 1, # +1 for the blank symbol - subsampling_factor=3) - model.P_scores = nn.Parameter(P.scores.clone(), requires_grad=True) - - model.to(device) - describe(model) - - if use_adam: - learning_rate = 1e-3 - weight_decay = 5e-4 - optimizer = optim.AdamW(model.parameters(), - lr=learning_rate, - weight_decay=weight_decay) - # Equivalent to the following in the epoch loop: - # if epoch > 6: - # curr_learning_rate *= 0.8 - lr_scheduler = optim.lr_scheduler.LambdaLR( - optimizer, - lambda ep: 1.0 if ep < 7 else 0.8 ** (ep - 6) - ) - else: - learning_rate = 5e-5 - weight_decay = 1e-5 - momentum = 0.9 - lr_schedule_gamma = 0.7 - optimizer = optim.SGD( - model.parameters(), - lr=learning_rate, - momentum=momentum, - weight_decay=weight_decay - ) - lr_scheduler = optim.lr_scheduler.ExponentialLR( - optimizer=optimizer, - gamma=lr_schedule_gamma - ) - - best_objf = np.inf - best_valid_objf = np.inf - best_epoch = start_epoch - best_model_path = os.path.join(exp_dir, 'best_model.pt') - best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info') - global_batch_idx_train = 0 # for logging only - - if start_epoch > 0: - model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(start_epoch - 1)) - ckpt = load_checkpoint(filename=model_path, model=model, optimizer=optimizer, scheduler=lr_scheduler) - best_objf = ckpt['objf'] - best_valid_objf = ckpt['valid_objf'] - global_batch_idx_train = ckpt['global_batch_idx_train'] - logging.info(f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}") - - for epoch in range(start_epoch, num_epochs): - train_sampler.set_epoch(epoch) - # LR scheduler can hold multiple learning rates for multiple parameter groups; - # For now we report just the first LR which we assume concerns most of the parameters. - curr_learning_rate = lr_scheduler.get_last_lr()[0] - tb_writer.add_scalar('train/learning_rate', curr_learning_rate, global_batch_idx_train) - tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train) - - logging.info('epoch {}, learning rate {}'.format(epoch, curr_learning_rate)) - objf, valid_objf, global_batch_idx_train = train_one_epoch( - dataloader=train_dl, - valid_dataloader=valid_dl, - model=model, - device=device, - graph_compiler=graph_compiler, - optimizer=optimizer, - current_epoch=epoch, - tb_writer=tb_writer, - num_epochs=num_epochs, - global_batch_idx_train=global_batch_idx_train, - ) - - lr_scheduler.step() - - # the lower, the better - if valid_objf < best_valid_objf: - best_valid_objf = valid_objf - best_objf = objf - best_epoch = epoch - save_checkpoint(filename=best_model_path, - model=model, - optimizer=None, - scheduler=None, - epoch=epoch, - learning_rate=curr_learning_rate, - objf=objf, - valid_objf=valid_objf, - global_batch_idx_train=global_batch_idx_train) - save_training_info(filename=best_epoch_info_filename, - model_path=best_model_path, - current_epoch=epoch, - learning_rate=curr_learning_rate, - objf=objf, - best_objf=best_objf, - valid_objf=valid_objf, - best_valid_objf=best_valid_objf, - best_epoch=best_epoch) - - # we always save the model for every epoch - model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch)) - save_checkpoint(filename=model_path, - model=model, - optimizer=optimizer, - scheduler=lr_scheduler, - epoch=epoch, - learning_rate=curr_learning_rate, - objf=objf, - valid_objf=valid_objf, - global_batch_idx_train=global_batch_idx_train) - epoch_info_filename = os.path.join(exp_dir, 'epoch-{}-info'.format(epoch)) - save_training_info(filename=epoch_info_filename, - model_path=model_path, - current_epoch=epoch, - learning_rate=curr_learning_rate, - objf=objf, - best_objf=best_objf, - valid_objf=valid_objf, - best_valid_objf=best_valid_objf, - best_epoch=best_epoch) - - logging.warning('Done') - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == '__main__': - main() From 4f655e3b7e2bbf3b5da54b3cdb6d4e2d2c2f3c03 Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Fri, 3 Sep 2021 11:35:03 +0800 Subject: [PATCH 04/10] Create mmi_bigram_train.py --- egs/aishell/asr/simple_v1/mmi_bigram_train.py | 518 ++++++++++++++++++ 1 file changed, 518 insertions(+) create mode 100644 egs/aishell/asr/simple_v1/mmi_bigram_train.py diff --git a/egs/aishell/asr/simple_v1/mmi_bigram_train.py b/egs/aishell/asr/simple_v1/mmi_bigram_train.py new file mode 100644 index 00000000..b5371018 --- /dev/null +++ b/egs/aishell/asr/simple_v1/mmi_bigram_train.py @@ -0,0 +1,518 @@ + +#!/usr/bin/env python3 +# Copyright (c) 2020 Xiaomi Corporation (authors: Daniel Povey, Haowen Qiu, Mingshuang Luo) +# 2021 Pingfeng Luo +# Apache 2.0 + +import k2 +import logging +import math +import numpy as np +import os +import sys +import torch +import torch.optim as optim +from datetime import datetime +from pathlib import Path +from torch import nn +from torch.nn.utils import clip_grad_value_, clip_grad_norm_ +from torch.utils.tensorboard import SummaryWriter +from typing import Dict, Optional, Tuple, List + +from lhotse import CutSet +from lhotse.dataset import CutConcatenate, CutMix, K2SpeechRecognitionDataset, SingleCutSampler +from lhotse.utils import fix_random_seed, nullcontext +from snowfall.common import describe +from snowfall.common import find_first_disambig_symbol +from snowfall.common import load_checkpoint, save_checkpoint, str2bool +from snowfall.common import save_training_info +from snowfall.common import setup_logger +from snowfall.dist import cleanup_dist, setup_dist +from snowfall.lexicon import Lexicon +from snowfall.models import AcousticModel +from snowfall.models.tdnn_lstm import TdnnLstm1b +from snowfall.objectives.mmi import LFMMILoss +from snowfall.training.diagnostics import measure_gradient_norms, optim_step_and_measure_param_change +from snowfall.training.mmi_graph import MmiTrainingGraphCompiler +from snowfall.training.mmi_graph import create_bigram_phone_lm +from snowfall.training.mmi_graph import get_phone_symbols + +den_scale = 1.0 + +def encode_supervisions(supervisions: Dict[str, torch.Tensor], + subsampling_factor) -> Tuple[torch.Tensor, List[str]]: + """ + Encodes Lhotse's ``batch["supervisions"]`` dict into a pair of torch Tensor, + and a list of transcription strings. + The supervision tensor has shape ``(batch_size, 3)``. + Its second dimension contains information about sequence index [0], + start frames [1] and num frames [2]. + The batch items might become re-ordered during this operation -- the returned tensor + and list of strings are guaranteed to be consistent with each other. + This mimics subsampling by a factor of 4 with Conv1D layer with no padding. + """ + supervision_segments = torch.stack( + (supervisions['sequence_idx'], + torch.floor_divide(supervisions['start_frame'], + subsampling_factor), + torch.floor_divide(supervisions['num_frames'], + subsampling_factor)), 1).to(torch.int32) + supervision_segments = torch.clamp(supervision_segments, min=0) + indices = torch.argsort(supervision_segments[:, 2], descending=True) + supervision_segments = supervision_segments[indices] + texts = supervisions['text'] + texts = [texts[idx] for idx in indices] + return supervision_segments, texts + +def get_objf(batch: Dict, + model: AcousticModel, + device: torch.device, + graph_compiler: MmiTrainingGraphCompiler, + is_training: bool, + tb_writer: Optional[SummaryWriter] = None, + global_batch_idx_train: Optional[int] = None, + optimizer: Optional[torch.optim.Optimizer] = None): + feature = batch['inputs'] + supervisions = batch['supervisions'] + supervision_segments = torch.stack( + (supervisions['sequence_idx'], + torch.floor_divide(supervisions['start_frame'], + model.subsampling_factor), + torch.floor_divide(supervisions['num_frames'], + model.subsampling_factor)), 1).to(torch.int32) + indices = torch.argsort(supervision_segments[:, 2], descending=True) + supervision_segments = supervision_segments[indices] + + texts = supervisions['text'] + texts = [texts[idx] for idx in indices] + assert feature.ndim == 3 + # print(supervision_segments[:, 1] + supervision_segments[:, 2]) + + feature = feature.to(device) + # at entry, feature is [N, T, C] + feature = feature.permute(0, 2, 1) # now feature is [N, C, T] + assert feature.ndim == 3 + feature = feature.to(device) + + try: + subsampling_factor = model.subsampling_factor + except: + subsampling_factor = model.module.subsampling_factor + + supervisions = batch['supervisions'] + supervision_segments, texts = encode_supervisions(supervisions, + subsampling_factor) + + loss_fn = LFMMILoss( + graph_compiler=graph_compiler, + den_scale=den_scale + ) + + grad_context = nullcontext if is_training else torch.no_grad + + with grad_context(): + nnet_output = model(feature) + # nnet_output is [N, C, T] + nnet_output = nnet_output.permute(0, 2, 1) # now nnet_output is [N, T, C] + mmi_loss, tot_frames, all_frames = loss_fn(nnet_output, texts, supervision_segments) + + if is_training: + def maybe_log_gradients(tag: str): + if ( + tb_writer is not None + and global_batch_idx_train is not None + and global_batch_idx_train % 200 == 0 + ): + tb_writer.add_scalars( + tag, + measure_gradient_norms(model, norm='l1'), + global_step=global_batch_idx_train + ) + + optimizer.zero_grad() + (-mmi_loss).backward() + + for name, param in model.named_parameters(): + if param.grad is None: + print(name) + + maybe_log_gradients('train/grad_norms') + #clip_grad_value_(model.parameters(), 5.0) + clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2.0) + maybe_log_gradients('train/clipped_grad_norms') + if tb_writer is not None and global_batch_idx_train % 200 == 0: + # Once in a time we will perform a more costly diagnostic + # to check the relative parameter change per minibatch. + deltas = optim_step_and_measure_param_change(model, optimizer) + tb_writer.add_scalars( + 'train/relative_param_change_per_minibatch', + deltas, + global_step=global_batch_idx_train + ) + else: + optimizer.step() + + ans = -mmi_loss.detach().cpu().item(), tot_frames.cpu().item(), all_frames.cpu().item() + return ans + + +def get_validation_objf(dataloader: torch.utils.data.DataLoader, + model: AcousticModel, + device: torch.device, + graph_compiler: MmiTrainingGraphCompiler): + total_objf = 0. + total_frames = 0. # for display only + total_all_frames = 0. # all frames including those seqs that failed. + + model.eval() + + for batch_idx, batch in enumerate(dataloader): + objf, frames, all_frames = get_objf(batch, model, device, + graph_compiler, False) + total_objf += objf + total_frames += frames + total_all_frames += all_frames + + return total_objf, total_frames, total_all_frames + + +def train_one_epoch(dataloader: torch.utils.data.DataLoader, + valid_dataloader: torch.utils.data.DataLoader, + model: AcousticModel, + device: torch.device, + graph_compiler: MmiTrainingGraphCompiler, + optimizer: torch.optim.Optimizer, + current_epoch: int, + tb_writer: SummaryWriter, + num_epochs: int, + global_batch_idx_train: int): + total_objf, total_frames, total_all_frames = 0., 0., 0. + valid_average_objf = float('inf') + time_waiting_for_batch = 0 + prev_timestamp = datetime.now() + + model.train() + for batch_idx, batch in enumerate(dataloader): + global_batch_idx_train += 1 + timestamp = datetime.now() + time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds() + + curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf( + batch=batch, + model=model, + device=device, + graph_compiler=graph_compiler, + is_training=True, + tb_writer=tb_writer, + global_batch_idx_train=global_batch_idx_train, + optimizer=optimizer + ) + + total_objf += curr_batch_objf + total_frames += curr_batch_frames + total_all_frames += curr_batch_all_frames + + if batch_idx % 10 == 0: + logging.info( + 'batch {}, epoch {}/{} ' + 'global average objf: {:.6f} over {} ' + 'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) ' + 'avg time waiting for batch {:.3f}s'.format( + batch_idx, current_epoch, num_epochs, + total_objf / total_frames, total_frames, + 100.0 * total_frames / total_all_frames, + curr_batch_objf / (curr_batch_frames + 0.001), + curr_batch_frames, + 100.0 * curr_batch_frames / curr_batch_all_frames, + time_waiting_for_batch / max(1, batch_idx))) + + print( + 'batch {}, epoch {}/{} ' + 'global average objf: {:.6f} over {} ' + 'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) ' + 'avg time waiting for batch {:.3f}s'.format( + batch_idx, current_epoch, num_epochs, + total_objf / total_frames, total_frames, + 100.0 * total_frames / total_all_frames, + curr_batch_objf / (curr_batch_frames + 0.001), + curr_batch_frames, + 100.0 * curr_batch_frames / curr_batch_all_frames, + time_waiting_for_batch / max(1, batch_idx))) + tb_writer.add_scalar('train/global_average_objf', + total_objf / total_frames, global_batch_idx_train) + + tb_writer.add_scalar('train/current_batch_average_objf', + curr_batch_objf / (curr_batch_frames + 0.001), + global_batch_idx_train) + # if batch_idx >= 10: + # print("Exiting early to get profile info") + # sys.exit(0) + + if batch_idx > 0 and batch_idx % 200 == 0: + total_valid_objf, total_valid_frames, total_valid_all_frames = get_validation_objf( + dataloader=valid_dataloader, + model=model, + device=device, + graph_compiler=graph_compiler) + valid_average_objf = total_valid_objf / total_valid_frames + model.train() + logging.info( + 'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)' + .format(valid_average_objf, + total_valid_frames, + 100.0 * total_valid_frames / total_valid_all_frames)) + + tb_writer.add_scalar('train/global_valid_average_objf', + valid_average_objf, + global_batch_idx_train) + model.write_tensorboard_diagnostics(tb_writer, global_step=global_batch_idx_train) + prev_timestamp = datetime.now() + return total_objf / total_frames, valid_average_objf, global_batch_idx_train + +def get_parser(): + import argparse + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('--world_size', default=1, type=int) + parser.add_argument('--local_rank', default=0, type=int) + parser.add_argument('--master_port', default=str(12345), type=str) + parser.add_argument('--bucketing_sampler', type=str2bool, default=True) + + return parser + +def main(): + args = get_parser().parse_args() + print('World size:', args.world_size, 'Rank:', args.local_rank) + setup_dist(rank=args.local_rank, world_size=args.world_size, master_port=args.master_port) + fix_random_seed(42) + + start_epoch = 0 + num_epochs = 10 + use_adam = True + + exp_dir = f'exp-lstm-adam-mmi-bigram-musan' + setup_logger('{}/log/log-train'.format(exp_dir)) + tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard') + + # load L, G, symbol_table + lang_dir = Path('data/lang_nosp') + lexicon = Lexicon(lang_dir) + + device_id = args.local_rank + device = torch.device('cuda', device_id) + phone_ids = lexicon.phone_symbols() + + if not Path(lang_dir / 'P.pt').is_file(): + logging.debug(f'Loading P from {lang_dir}/P.fst.txt') + with open(lang_dir / 'P.fst.txt') as f: + # P is not an acceptor because there is + # a back-off state, whose incoming arcs + # have label #0 and aux_label eps. + P = k2.Fsa.from_openfst(f.read(), acceptor=False) + + phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt') + first_phone_disambig_id = find_first_disambig_symbol(phone_symbol_table) + + # P.aux_labels is not needed in later computations, so + # remove it here. + del P.aux_labels + # CAUTION(fangjun): The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + P.labels[P.labels >= first_phone_disambig_id] = 0 + + P = k2.remove_epsilon(P) + P = k2.arc_sort(P) + torch.save(P.as_dict(), lang_dir / 'P.pt') + else: + logging.debug('Loading pre-compiled P') + d = torch.load(lang_dir / 'P.pt') + P = k2.Fsa.from_dict(d) + + graph_compiler = MmiTrainingGraphCompiler( + lexicon=lexicon, + P=P, + device=device, + ) + + # load dataset + feature_dir = Path('exp/data') + logging.info("About to get train cuts") + cuts_train = CutSet.from_json(feature_dir / + 'cuts_train.json.gz') + logging.info("About to get dev cuts") + cuts_dev = CutSet.from_json(feature_dir / 'cuts_dev.json.gz') + logging.info("About to get Musan cuts") + cuts_musan = CutSet.from_json(feature_dir / 'cuts_musan.json.gz') + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + cuts_train, + cut_transforms=[ + CutConcatenate(), + CutMix( + cuts=cuts_musan, + prob=0.5, + snr=(10, 20) + ) + ] + ) + train_sampler = SingleCutSampler( + cuts_train, + max_frames=40000, + shuffle=True, + ) + logging.info("About to create train dataloader") + train_dl = torch.utils.data.DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=4 + ) + logging.info("About to create dev dataset") + validate = K2SpeechRecognitionDataset(cuts_dev) + valid_sampler = SingleCutSampler(cuts_dev, max_frames=12000) + logging.info("About to create dev dataloader") + valid_dl = torch.utils.data.DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=1 + ) + + if not torch.cuda.is_available(): + logging.error('No GPU detected!') + sys.exit(-1) + + logging.info("About to create model") + device_id = 0 + device = torch.device('cuda', device_id) + model = TdnnLstm1b(num_features=40, + num_classes=len(phone_ids) + 1, # +1 for the blank symbol + subsampling_factor=3) + model.P_scores = nn.Parameter(P.scores.clone(), requires_grad=True) + + model.to(device) + describe(model) + + if use_adam: + learning_rate = 1e-3 + weight_decay = 5e-4 + optimizer = optim.AdamW(model.parameters(), + lr=learning_rate, + weight_decay=weight_decay) + # Equivalent to the following in the epoch loop: + # if epoch > 6: + # curr_learning_rate *= 0.8 + lr_scheduler = optim.lr_scheduler.LambdaLR( + optimizer, + lambda ep: 1.0 if ep < 7 else 0.8 ** (ep - 6) + ) + else: + learning_rate = 5e-5 + weight_decay = 1e-5 + momentum = 0.9 + lr_schedule_gamma = 0.7 + optimizer = optim.SGD( + model.parameters(), + lr=learning_rate, + momentum=momentum, + weight_decay=weight_decay + ) + lr_scheduler = optim.lr_scheduler.ExponentialLR( + optimizer=optimizer, + gamma=lr_schedule_gamma + ) + + best_objf = np.inf + best_valid_objf = np.inf + best_epoch = start_epoch + best_model_path = os.path.join(exp_dir, 'best_model.pt') + best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info') + global_batch_idx_train = 0 # for logging only + + if start_epoch > 0: + model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(start_epoch - 1)) + ckpt = load_checkpoint(filename=model_path, model=model, optimizer=optimizer, scheduler=lr_scheduler) + best_objf = ckpt['objf'] + best_valid_objf = ckpt['valid_objf'] + global_batch_idx_train = ckpt['global_batch_idx_train'] + logging.info(f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}") + + for epoch in range(start_epoch, num_epochs): + train_sampler.set_epoch(epoch) + # LR scheduler can hold multiple learning rates for multiple parameter groups; + # For now we report just the first LR which we assume concerns most of the parameters. + curr_learning_rate = lr_scheduler.get_last_lr()[0] + tb_writer.add_scalar('train/learning_rate', curr_learning_rate, global_batch_idx_train) + tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train) + + logging.info('epoch {}, learning rate {}'.format(epoch, curr_learning_rate)) + objf, valid_objf, global_batch_idx_train = train_one_epoch( + dataloader=train_dl, + valid_dataloader=valid_dl, + model=model, + device=device, + graph_compiler=graph_compiler, + optimizer=optimizer, + current_epoch=epoch, + tb_writer=tb_writer, + num_epochs=num_epochs, + global_batch_idx_train=global_batch_idx_train, + ) + + lr_scheduler.step() + + # the lower, the better + if valid_objf < best_valid_objf: + best_valid_objf = valid_objf + best_objf = objf + best_epoch = epoch + save_checkpoint(filename=best_model_path, + model=model, + optimizer=None, + scheduler=None, + epoch=epoch, + learning_rate=curr_learning_rate, + objf=objf, + valid_objf=valid_objf, + global_batch_idx_train=global_batch_idx_train) + save_training_info(filename=best_epoch_info_filename, + model_path=best_model_path, + current_epoch=epoch, + learning_rate=curr_learning_rate, + objf=objf, + best_objf=best_objf, + valid_objf=valid_objf, + best_valid_objf=best_valid_objf, + best_epoch=best_epoch) + + # we always save the model for every epoch + model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch)) + save_checkpoint(filename=model_path, + model=model, + optimizer=optimizer, + scheduler=lr_scheduler, + epoch=epoch, + learning_rate=curr_learning_rate, + objf=objf, + valid_objf=valid_objf, + global_batch_idx_train=global_batch_idx_train) + epoch_info_filename = os.path.join(exp_dir, 'epoch-{}-info'.format(epoch)) + save_training_info(filename=epoch_info_filename, + model_path=model_path, + current_epoch=epoch, + learning_rate=curr_learning_rate, + objf=objf, + best_objf=best_objf, + valid_objf=valid_objf, + best_valid_objf=best_valid_objf, + best_epoch=best_epoch) + + logging.warning('Done') + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == '__main__': + main() From d892db77618d180cd9bcbf301eaee247d4c93643 Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Sun, 12 Sep 2021 15:26:21 +0800 Subject: [PATCH 05/10] Options for the number of phones Choose a number (60, 48, 29) of phones for modeling --- .../asr/simple_v1/local2/timit_data_prep.sh | 19 ++++++++++--------- egs/timit/asr/simple_v1/run.sh | 7 +++++-- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/egs/timit/asr/simple_v1/local2/timit_data_prep.sh b/egs/timit/asr/simple_v1/local2/timit_data_prep.sh index c7b3e391..c5ca25c7 100644 --- a/egs/timit/asr/simple_v1/local2/timit_data_prep.sh +++ b/egs/timit/asr/simple_v1/local2/timit_data_prep.sh @@ -4,8 +4,9 @@ # 2014 Brno University of Technology (Author: Karel Vesely) # Apache 2.0. -if [ $# -ne 1 ]; then - echo "Argument should be the Timit directory, see ../run.sh for example." +if [ $# -ne 2 ]; then + echo "Usage: local/timit_data_prep.sh " + echo "e.g.: local/timit_data_prep.sh data 48" exit 1; fi @@ -19,7 +20,7 @@ conf=`pwd`/local2 # First check if the train & test directories exist (these can either be upper- # or lower-cased -if [ ! -d $*/TRAIN -o ! -d $*/TEST ] && [ ! -d $*/train -o ! -d $*/test ]; then +if [ ! -d $1/TRAIN -o ! -d $1/TEST ] && [ ! -d $1/train -o ! -d $1/test ]; then echo "timit_data_prep.sh: Spot check of command line argument failed" echo "Command line argument must be absolute pathname to TIMIT directory" echo "with name like /export/corpora5/LDC/LDC93S1/timit/TIMIT" @@ -30,7 +31,7 @@ fi uppercased=false train_dir=train test_dir=test -if [ -d $*/TRAIN ]; then +if [ -d $1/TRAIN ]; then uppercased=true train_dir=TRAIN test_dir=TEST @@ -43,9 +44,9 @@ trap 'rm -rf "$tmpdir"' EXIT # set and the 50-speaker development set must be supplied to the script. All # speakers in the 'train' directory are used for training. if $uppercased; then - ls -d "$*"/TRAIN/DR*/* | sed -e "s:^.*/::" > $tmpdir/train_spk + ls -d "$1"/TRAIN/DR*/* | sed -e "s:^.*/::" > $tmpdir/train_spk else - ls -d "$*"/train/dr*/* | sed -e "s:^.*/::" > $tmpdir/train_spk + ls -d "$1"/train/dr*/* | sed -e "s:^.*/::" > $tmpdir/train_spk fi cd $dir @@ -53,7 +54,7 @@ for x in train; do # First, find the list of audio files (use only si & sx utterances). # Note: train & test sets are under different directories, but doing find on # both and grepping for the speakers will work correctly. - find $*/{$train_dir,$test_dir} -not \( -iname 'SA*' \) -iname '*.WAV' \ + find $1/{$train_dir,$test_dir} -not \( -iname 'SA*' \) -iname '*.WAV' \ | grep -f $tmpdir/${x}_spk > ${x}_sph.flist sed -e 's:.*/\(.*\)/\(.*\).\(WAV\|wav\)$:\1_\2:' ${x}_sph.flist \ @@ -66,7 +67,7 @@ for x in train; do # Now, Convert the transcripts into our format (no normalization yet) # Get the transcripts: each line of the output contains an utterance # ID followed by the transcript. - find $*/{$train_dir,$test_dir} -not \( -iname 'SA*' \) -iname '*.PHN' \ + find $1/{$train_dir,$test_dir} -not \( -iname 'SA*' \) -iname '*.PHN' \ | grep -f $tmpdir/${x}_spk > $tmpdir/${x}_phn.flist sed -e 's:.*/\(.*\)/\(.*\).\(PHN\|phn\)$:\1_\2:' $tmpdir/${x}_phn.flist \ > $tmpdir/${x}_phn.uttids @@ -78,7 +79,7 @@ for x in train; do | sort -k1,1 > ${x}.trans # Do normalization steps. - cat ${x}.trans | $conf/timit_norm_trans.pl -i - -m $conf/phones.60-48-39.map -to 48 | sort > $x.text || exit 1; + cat ${x}.trans | $conf/timit_norm_trans.pl -i - -m $conf/phones.60-48-39.map -to $2 | sort > $x.text || exit 1; done diff --git a/egs/timit/asr/simple_v1/run.sh b/egs/timit/asr/simple_v1/run.sh index 728eb4b0..7e564510 100644 --- a/egs/timit/asr/simple_v1/run.sh +++ b/egs/timit/asr/simple_v1/run.sh @@ -21,6 +21,8 @@ for d in ${dataset_path[@]}; do fi done +num_phones=48 ##Choose a number from {60, 48, 39} + if [ ! -d $data ]; then echo "$data does not exist" exit 1 @@ -32,7 +34,7 @@ stage=1 if [ $stage -le 1 ]; then echo "Data preparation" - local2/timit_data_prep.sh $data + local2/timit_data_prep.sh $data $num_phones fi @@ -90,7 +92,8 @@ fi if [ $stage -le 5 ]; then echo "Feature preparation" - python3 ./prepare.py + python3 ./prepare.py \ + --num-phones=$num_phones fi if [ $stage -le 6 ]; then From 5ee228f9c24236326dac822b02b06caed6d23301 Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Sun, 12 Sep 2021 20:38:48 +0800 Subject: [PATCH 06/10] Update train_lms.sh --- egs/timit/asr/simple_v1/local2/train_lms.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/timit/asr/simple_v1/local2/train_lms.sh b/egs/timit/asr/simple_v1/local2/train_lms.sh index 0d0be560..8bf9221d 100644 --- a/egs/timit/asr/simple_v1/local2/train_lms.sh +++ b/egs/timit/asr/simple_v1/local2/train_lms.sh @@ -11,7 +11,7 @@ for f in "$text" "$lexicon"; do done # This script takes no arguments. It assumes you have already run -# aishell_data_prep.sh. +# timit_data_prep.sh. # It takes as input the files # data/local/train/text # data/local/dict/lexicon.txt From 44e7157e2ba23206d3f0a70e65770437c0895d4a Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Sun, 12 Sep 2021 20:49:00 +0800 Subject: [PATCH 07/10] Update run.sh --- egs/timit/asr/simple_v1/run.sh | 2 -- 1 file changed, 2 deletions(-) diff --git a/egs/timit/asr/simple_v1/run.sh b/egs/timit/asr/simple_v1/run.sh index 7e564510..ed9da367 100644 --- a/egs/timit/asr/simple_v1/run.sh +++ b/egs/timit/asr/simple_v1/run.sh @@ -28,8 +28,6 @@ if [ ! -d $data ]; then exit 1 fi -[ -f path.sh ] && . ./path.sh - stage=1 if [ $stage -le 1 ]; then From 8bed1236d4f0da587566aa07f5a0fcd9f29bfffd Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Sun, 12 Sep 2021 20:49:21 +0800 Subject: [PATCH 08/10] Delete path.sh --- egs/timit/asr/simple_v1/path.sh | 3 --- 1 file changed, 3 deletions(-) delete mode 100644 egs/timit/asr/simple_v1/path.sh diff --git a/egs/timit/asr/simple_v1/path.sh b/egs/timit/asr/simple_v1/path.sh deleted file mode 100644 index 4e9d61ef..00000000 --- a/egs/timit/asr/simple_v1/path.sh +++ /dev/null @@ -1,3 +0,0 @@ -export SNOWFALL_ROOT=`pwd`/../../../.. -[ -f $SNOWFALL_ROOT/tools/env.sh ] && . $SNOWFALL_ROOT/tools/env.sh -export LC_ALL=C \ No newline at end of file From 07b9bdb2c9bd161feaed835896c0b807d3536926 Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Mon, 13 Sep 2021 13:17:50 +0800 Subject: [PATCH 09/10] Add CRDNN model for timit recipe Add CRDNN model for timit recipe. The results get better based on this model than TdnnLstm1b. --- egs/timit/asr/simple_v1/RESULTS.md | 47 +- .../asr/simple_v1/ctc_decode_with_CRDNN.py | 236 +++++++++ .../asr/simple_v1/ctc_train_with_CRDNN.py | 456 ++++++++++++++++++ 3 files changed, 738 insertions(+), 1 deletion(-) create mode 100644 egs/timit/asr/simple_v1/ctc_decode_with_CRDNN.py create mode 100644 egs/timit/asr/simple_v1/ctc_train_with_CRDNN.py diff --git a/egs/timit/asr/simple_v1/RESULTS.md b/egs/timit/asr/simple_v1/RESULTS.md index 95694bc9..066bcf95 100644 --- a/egs/timit/asr/simple_v1/RESULTS.md +++ b/egs/timit/asr/simple_v1/RESULTS.md @@ -3,7 +3,8 @@ ## 2021-09-03 (Mingshuang Luo): -### TIMIT CTC_Train +### TIMIT CTC_Train Based on 48 phones + Testing results based on different training epochs: ``` epoch=20 @@ -18,3 +19,47 @@ epoch=35 epoch=40 2021-09-03 11:12:39,029 INFO [ctc_decode.py:188] %PER 29.52% [2165 / 7333, 304 ins, 348 del, 1513 sub ] ``` + +### TIMIT CTC_Train Based on 39 phones + +Testing results based on different training epochs: +``` +epoch=40 +2021-09-13 11:02:14,793 INFO [ctc_decode.py:189] %PER 25.61% [1848 / 7215, 301 ins, 396 del, 1151 sub ] + +epoch=45 +2021-09-13 11:01:20,787 INFO [ctc_decode.py:189] %PER 25.50% [1840 / 7215, 286 ins, 386 del, 1168 sub ] + +epoch=47 +2021-09-13 11:04:05,533 INFO [ctc_decode.py:189] %PER 26.20% [1890 / 7215, 373 ins, 367 del, 1150 sub ] + +``` +### TIMIT CTC_TRAIN_with_CRDNN Based on 48 phones + +Testing results based on different training epochs: +``` +epoch=35 +2021-09-13 11:21:01,592 INFO [ctc_crdnn_decode.py:201] %PER 20.46% [1476 / 7215, 249 ins, 356 del, 871 sub ] + +epoch=45 +2021-09-13 11:22:02,221 INFO [ctc_crdnn_decode.py:201] %PER 19.75% [1425 / 7215, 239 ins, 324 del, 862 sub ] + +epoch=53 +2021-09-13 11:23:07,969 INFO [ctc_crdnn_decode.py:201] %PER 18.86% [1361 / 7215, 214 ins, 320 del, 827 sub ] + +``` + +### TIMIT CTC_TRAIN_with_CRDNN Based on 39 phones + +Testing results based on different training epochs: +``` +epoch=26 +2021-09-13 11:32:41,388 INFO [ctc_crdnn_decode.py:201] %PER 21.04% [1518 / 7215, 345 ins, 251 del, 922 sub ] + +epoch=45 +2021-09-13 11:34:27,566 INFO [ctc_crdnn_decode.py:201] %PER 18.74% [1352 / 7215, 316 ins, 239 del, 797 sub ] + +epoch=55 +2021-09-13 11:35:55,751 INFO [ctc_crdnn_decode.py:201] %PER 18.24% [1316 / 7215, 267 ins, 242 del, 807 sub ] + +``` \ No newline at end of file diff --git a/egs/timit/asr/simple_v1/ctc_decode_with_CRDNN.py b/egs/timit/asr/simple_v1/ctc_decode_with_CRDNN.py new file mode 100644 index 00000000..2a981f78 --- /dev/null +++ b/egs/timit/asr/simple_v1/ctc_decode_with_CRDNN.py @@ -0,0 +1,236 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2020 Xiaomi Corporation (authors: Daniel Povey, Haowen Qiu) +# 2021 Xiaomi Corporation (authors: Mingshuang Luo) +# Apache 2.0 + +# Notice: before you run this script, you should install speechbrain first. +# You can install speechbrain by "pip install speechbrain". + +import k2 +import logging +import os +import torch +import torch.nn as nn + +from k2 import Fsa, SymbolTable +from kaldialign import edit_distance +from pathlib import Path +from typing import Union + +from lhotse import CutSet +from lhotse.dataset import K2SpeechRecognitionDataset +from lhotse.dataset import SingleCutSampler +from snowfall.common import find_first_disambig_symbol +from snowfall.common import get_phone_symbols +from snowfall.common import get_texts +from snowfall.common import load_checkpoint +from snowfall.common import setup_logger +from snowfall.decoding.graph import compile_HLG +from snowfall.models import AcousticModel +from snowfall.training.ctc_graph import build_ctc_topo + +import sys +import argparse + +from speechbrain.lobes.models.CRDNN import CRDNN +from speechbrain.nnet.linear import Linear + +class crdnn_model(nn.Module): + def __init__(self, num_features:int, + num_classes:int, + subsampling_factor: int, + crdnn, + linear): + + super(crdnn_model, self).__init__() + self.num_features = num_features + self.num_classes = num_classes + self.subsampling_factor = subsampling_factor + + self.crdnn = crdnn + self.linear=linear + + def forward(self, x): + x = self.crdnn(x) + x = self.linear(x) + x = x.transpose(1,2) + x = nn.functional.log_softmax(x, dim=1) + + return x + +def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel, + device: Union[str, torch.device], HLG: Fsa, symbols: SymbolTable): + tot_num_cuts = len(dataloader) + num_cuts = 0 + results = [] # a list of pair (ref_words, hyp_words) + for batch_idx, batch in enumerate(dataloader): + + feature = batch['inputs'] + supervisions = batch['supervisions'] + supervision_segments = torch.stack( + (supervisions['sequence_idx'], + torch.floor_divide(supervisions['start_frame'], + model.subsampling_factor), + torch.floor_divide(supervisions['num_frames'], + model.subsampling_factor)), 1).to(torch.int32) + indices = torch.argsort(supervision_segments[:, 2], descending=True) + supervision_segments = supervision_segments[indices] + texts = supervisions['text'] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is [N, T, C] + #feature = feature.permute(0, 2, 1) # now feature is [N, C, T] + with torch.no_grad(): + nnet_output = model(feature) + # nnet_output is [N, C, T] + nnet_output = nnet_output.permute(0, 2, + 1) # now nnet_output is [N, T, C] + + blank_bias = -3.0 + nnet_output[:, :, 0] += blank_bias + + dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments) + # assert HLG.is_cuda() + assert HLG.device == nnet_output.device, \ + f"Check failed: HLG.device ({HLG.device}) == nnet_output.device ({nnet_output.device})" + # TODO(haowen): with a small `beam`, we may get empty `target_graph`, + # thus `tot_scores` will be `inf`. Definitely we need to handle this later. + lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0, 7.0, 30, 10000) + + # lattices = k2.intersect_dense(HLG, dense_fsa_vec, 10.0) + best_paths = k2.shortest_path(lattices, use_double_scores=True) + assert best_paths.shape[0] == len(texts) + hyps = get_texts(best_paths, indices) + assert len(hyps) == len(texts) + + for i in range(len(texts)): + hyp_words = [symbols.get(x) for x in hyps[i]] + ref_words = texts[i].split(' ') + results.append((ref_words, hyp_words)) + + if batch_idx % 10 == 0: + logging.info( + 'batch {}, cuts processed until now is {}/{} ({:.6f}%)'.format( + batch_idx, num_cuts, tot_num_cuts, + float(num_cuts) / tot_num_cuts * 100)) + + num_cuts += len(texts) + + return results + + +def main(): + import argparse + + parser = argparse.ArgumentParser(description='Process some integers.') + parser.add_argument('--epoch', type=int, default=20, + help='the checkpoint for loading.') + + parser.add_argument('--mode', type=str, default='TEST', + help='the mode to test.') + + args = parser.parse_args() + + exp_dir = Path('exp-lstm-adam-ctc-musan-with-crdnn') + setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug') + + # load L, G, symbol_table + lang_dir = Path('data/lang_nosp') + symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt') + phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt') + phone_ids = get_phone_symbols(phone_symbol_table) + phone_ids_with_blank = [0] + phone_ids + ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank)) + + if not os.path.exists(lang_dir / 'HLG.pt'): + print("Loading L_disambig.fst.txt") + with open(lang_dir / 'L_disambig.fst.txt') as f: + L = k2.Fsa.from_openfst(f.read(), acceptor=False) + print("Loading G.fst.txt") + with open(lang_dir / 'G.fst.txt') as f: + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + first_phone_disambig_id = find_first_disambig_symbol(phone_symbol_table) + first_word_disambig_id = find_first_disambig_symbol(symbol_table) + HLG = compile_HLG(L=L, + G=G, + H=ctc_topo, + labels_disambig_id_start=first_phone_disambig_id, + aux_labels_disambig_id_start=first_word_disambig_id) + torch.save(HLG.as_dict(), lang_dir / 'HLG.pt') + else: + print("Loading pre-compiled HLG") + d = torch.load(lang_dir / 'HLG.pt') + HLG = k2.Fsa.from_dict(d) + + # load dataset + feature_dir = Path('exp/data') + print("About to get test cuts") + cuts_test = CutSet.from_json(feature_dir / 'cuts_{}.json.gz'.format(args.mode)) + + print("About to create test dataset") + test = K2SpeechRecognitionDataset(cuts_test) + sampler = SingleCutSampler(cuts_test, max_frames=100000) + print("About to create test dataloader") + test_dl = torch.utils.data.DataLoader(test, batch_size=None, sampler=sampler, num_workers=1) + + # if not torch.cuda.is_available(): + # logging.error('No GPU detected!') + # sys.exit(-1) + + print("About to load model") + # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N + # device = torch.device('cuda', 1) + device = torch.device('cuda') + + crdnn = CRDNN( + input_size=80, + time_pooling=True) + + linear = Linear( + input_size=512, + n_neurons=len(phone_ids) + 1) + + model = crdnn_model(80, len(phone_ids)+1, 2, crdnn, linear) + + checkpoint = os.path.join(exp_dir, 'epoch-{}.pt'.format(args.epoch)) + load_checkpoint(checkpoint, model) + model.to(device) + model.eval() + + print("convert HLG to device") + HLG = HLG.to(device) + HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0) + HLG.requires_grad_(False) + print("About to decode") + results = decode(dataloader=test_dl, + model=model, + device=device, + HLG=HLG, + symbols=symbol_table) + s = '' + for ref, hyp in results: + s += f'ref={ref}\n' + s += f'hyp={hyp}\n' + #logging.info(s) + results = [([one for one in n[0] if one], [one for one in n[1] if one]) for n in results] + # compute WER + dists = [edit_distance(r, h) for r, h in results] + errors = { + key: sum(dist[key] for dist in dists) + for key in ['sub', 'ins', 'del', 'total'] + } + total_words = sum(len(ref) for ref, _ in results) + + logging.info( + f'%PER {errors["total"] / total_words:.2%} ' + f'[{errors["total"]} / {total_words}, {errors["ins"]} ins, {errors["del"]} del, {errors["sub"]} sub ]' + ) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/egs/timit/asr/simple_v1/ctc_train_with_CRDNN.py b/egs/timit/asr/simple_v1/ctc_train_with_CRDNN.py new file mode 100644 index 00000000..109e1110 --- /dev/null +++ b/egs/timit/asr/simple_v1/ctc_train_with_CRDNN.py @@ -0,0 +1,456 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2020 Xiaomi Corporation (authors: Daniel Povey, Haowen Qiu) +# 2021 Xiaomi Corporation (authors: Mingshuang Luo) +# Apache 2.0 + +# Notice: before you run this script, you should install speechbrain first. +# You can install speechbrain by "pip install speechbrain". + +import k2 +import logging +import math +import numpy as np +import os +import sys + +import torch +import torch.nn as nn +import torch.optim as optim +from datetime import datetime +from pathlib import Path +from torch.nn.utils import clip_grad_value_, clip_grad_norm_ +from torch.utils.tensorboard import SummaryWriter +from typing import Dict, Optional, Tuple + +from lhotse import CutSet +from lhotse.dataset import CutConcatenate, CutMix, K2SpeechRecognitionDataset, SingleCutSampler +from lhotse.utils import fix_random_seed +from snowfall.common import describe +from snowfall.common import get_phone_symbols +from snowfall.common import load_checkpoint, save_checkpoint +from snowfall.common import save_training_info +from snowfall.common import setup_logger +from snowfall.models import AcousticModel +from snowfall.training.ctc_graph import CtcTrainingGraphCompiler + +from speechbrain.lobes.models.CRDNN import CRDNN +from speechbrain.nnet.linear import Linear + +class crdnn_model(nn.Module): + def __init__(self, num_features:int, + num_classes:int, + subsampling_factor: int, + crdnn, + linear): + + super(crdnn_model, self).__init__() + self.num_features = num_features + self.num_classes = num_classes + self.subsampling_factor = subsampling_factor + + self.crdnn = crdnn + self.linear=linear + + def forward(self, x): + x = self.crdnn(x) + x = self.linear(x) + x = x.transpose(1,2) + x = nn.functional.log_softmax(x, dim=1) + + return x + +def get_tot_objf_and_num_frames(tot_scores: torch.Tensor, + frames_per_seq: torch.Tensor + ) -> Tuple[float, int, int]: + ''' Figures out the total score(log-prob) over all successful supervision segments + (i.e. those for which the total score wasn't -infinity), and the corresponding + number of frames of neural net output + Args: + tot_scores: a Torch tensor of shape (num_segments,) containing total scores + from forward-backward + frames_per_seq: a Torch tensor of shape (num_segments,) containing the number of + frames for each segment + Returns: + Returns a tuple of 3 scalar tensors: (tot_score, ok_frames, all_frames) + where ok_frames is the frames for successful (finite) segments, and + all_frames is the frames for all segments (finite or not). + ''' + mask = torch.ne(tot_scores, -math.inf) + # finite_indexes is a tensor containing successful segment indexes, e.g. + # [ 0 1 3 4 5 ] + finite_indexes = torch.nonzero(mask).squeeze(1) + if False: + bad_indexes = torch.nonzero(~mask).squeeze(1) + if bad_indexes.shape[0] > 0: + print("Bad indexes: ", bad_indexes, ", bad lengths: ", + frames_per_seq[bad_indexes], " vs. max length ", + torch.max(frames_per_seq), ", avg ", + (torch.sum(frames_per_seq) / frames_per_seq.numel())) + + ok_frames = frames_per_seq[finite_indexes].sum() + all_frames = frames_per_seq.sum() + return (tot_scores[finite_indexes].sum(), ok_frames, all_frames) + + +def get_objf(batch: Dict, + model: AcousticModel, + device: torch.device, + graph_compiler: CtcTrainingGraphCompiler, + training: bool, + optimizer: Optional[torch.optim.Optimizer] = None): + feature = batch['inputs'] + supervisions = batch['supervisions'] + supervision_segments = torch.stack( + (supervisions['sequence_idx'], + torch.floor_divide(supervisions['start_frame'], + 2), + torch.floor_divide(supervisions['num_frames'], + 2)), 1).to(torch.int32) + indices = torch.argsort(supervision_segments[:, 2], descending=True) + supervision_segments = supervision_segments[indices] + + texts = supervisions['text'] + texts = [texts[idx] for idx in indices] + assert feature.ndim == 3 + # print(supervision_segments[:, 1] + supervision_segments[:, 2]) + + feature = feature.to(device) + # at entry, feature is [N, T, C] + #feature = feature.permute(0, 2, 1) # now feature is [N, C, T] + if training: + nnet_output = model(feature) + else: + with torch.no_grad(): + nnet_output = model(feature) + + # nnet_output is [N, C, T] + nnet_output = nnet_output.permute(0, 2, 1) # now nnet_output is [N, T, C] + #print('nnet_output: ', nnet_output.size(), nnet_output) + decoding_graph = graph_compiler.compile(texts).to(device) + + # nnet_output2 = nnet_output.clone() + # blank_bias = -7.0 + # nnet_output2[:,:,0] += blank_bias + + dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments) + assert decoding_graph.is_cuda() + assert decoding_graph.device == device + assert nnet_output.device == device + + target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec, 10.0) + + tot_scores = target_graph.get_tot_scores( + log_semiring=True, + use_double_scores=True) + + (tot_score, tot_frames, + all_frames) = get_tot_objf_and_num_frames(tot_scores, + supervision_segments[:, 2]) + + if training: + optimizer.zero_grad() + (-tot_score).backward() + #clip_grad_value_(model.parameters(), 5.0) + clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2.0) + optimizer.step() + + ans = -tot_score.detach().cpu().item(), tot_frames.cpu().item( + ), all_frames.cpu().item() + return ans + + +def get_validation_objf(dataloader: torch.utils.data.DataLoader, + model: AcousticModel, device: torch.device, + graph_compiler: CtcTrainingGraphCompiler): + total_objf = 0. + total_frames = 0. # for display only + total_all_frames = 0. # all frames including those seqs that failed. + + model.eval() + + for batch_idx, batch in enumerate(dataloader): + objf, frames, all_frames = get_objf(batch, model, device, + graph_compiler, False) + total_objf += objf + total_frames += frames + total_all_frames += all_frames + + return total_objf, total_frames, total_all_frames + + +def train_one_epoch(dataloader: torch.utils.data.DataLoader, + valid_dataloader: torch.utils.data.DataLoader, + model: AcousticModel, device: torch.device, + graph_compiler: CtcTrainingGraphCompiler, + optimizer: torch.optim.Optimizer, + current_epoch: int, + tb_writer: SummaryWriter, + num_epochs: int, + global_batch_idx_train: int): + total_objf, total_frames, total_all_frames = 0., 0., 0. + valid_average_objf = float('inf') + time_waiting_for_batch = 0 + prev_timestamp = datetime.now() + + model.train() + for batch_idx, batch in enumerate(dataloader): + global_batch_idx_train += 1 + timestamp = datetime.now() + time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds() + curr_batch_objf, curr_batch_frames, curr_batch_all_frames = \ + get_objf(batch, model, device, graph_compiler, True, optimizer) + + total_objf += curr_batch_objf + total_frames += curr_batch_frames + total_all_frames += curr_batch_all_frames + + if batch_idx % 10 == 0: + logging.info( + 'batch {}, epoch {}/{} ' + 'global average objf: {:.6f} over {} ' + 'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) ' + 'avg time waiting for batch {:.3f}s'.format( + batch_idx, current_epoch, num_epochs, + total_objf / total_frames, total_frames, + 100.0 * total_frames / total_all_frames, + curr_batch_objf / (curr_batch_frames + 0.001), + curr_batch_frames, + 100.0 * curr_batch_frames / curr_batch_all_frames, + time_waiting_for_batch / max(1, batch_idx))) + + tb_writer.add_scalar('train/global_average_objf', + total_objf / total_frames, global_batch_idx_train) + + tb_writer.add_scalar('train/current_batch_average_objf', + curr_batch_objf / (curr_batch_frames + 0.001), + global_batch_idx_train) + # if batch_idx >= 10: + # print("Exiting early to get profile info") + # sys.exit(0) + + if batch_idx > 0 and batch_idx % 10 == 0: + total_valid_objf, total_valid_frames, total_valid_all_frames = get_validation_objf( + dataloader=valid_dataloader, + model=model, + device=device, + graph_compiler=graph_compiler) + valid_average_objf = total_valid_objf / total_valid_frames + model.train() + logging.info( + 'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)' + .format(valid_average_objf, + total_valid_frames, + 100.0 * total_valid_frames / total_valid_all_frames)) + + tb_writer.add_scalar('train/global_valid_average_objf', + valid_average_objf, + global_batch_idx_train) + prev_timestamp = datetime.now() + if batch_idx >= 50: + break + return total_objf / total_frames, valid_average_objf, global_batch_idx_train + + +def main(): + import argparse + + parser = argparse.ArgumentParser(description='Process some integers.') + parser.add_argument('--epochs', type=int, default=100, help='the number of epoch for training.') + + parser.add_argument('--duration', type=int, default=200, help='the max duration in a batch for training.') + + args = parser.parse_args() + + fix_random_seed(42) + + start_epoch = 7 + num_epochs = args.epochs + + exp_dir = 'exp-lstm-adam-ctc-musan-with-crdnn' + setup_logger('{}/log/log-train'.format(exp_dir)) + tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard') + + # load L, G, symbol_table + lang_dir = Path('data/lang_nosp') + phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt') + word_symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt') + + logging.info("Loading L.fst") + if (lang_dir / 'Linv.pt').exists(): + L_inv = k2.Fsa.from_dict(torch.load(lang_dir / 'Linv.pt')) + else: + with open(lang_dir / 'L.fst.txt') as f: + L = k2.Fsa.from_openfst(f.read(), acceptor=False) + L_inv = k2.arc_sort(L.invert_()) + torch.save(L_inv.as_dict(), lang_dir / 'Linv.pt') + + graph_compiler = CtcTrainingGraphCompiler( + L_inv=L_inv, + phones=phone_symbol_table, + words=word_symbol_table + ) + phone_ids = get_phone_symbols(phone_symbol_table) + + # load dataset + feature_dir = Path('exp/data') + logging.info("About to get train cuts") + cuts_train = CutSet.from_json(feature_dir / + 'cuts_TRAIN.json.gz') + logging.info("About to get dev cuts") + cuts_dev = CutSet.from_json(feature_dir / 'cuts_TEST.json.gz') + logging.info("About to get Musan cuts") + cuts_musan = CutSet.from_json(feature_dir / 'cuts_musan.json.gz') + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + cuts_train, + cut_transforms=[ + CutConcatenate(), + CutMix( + cuts=cuts_musan, + prob=0.5, + snr=(10, 20) + ) + ] + ) + train_sampler = SingleCutSampler( + cuts_train, + #max_frames=180000, + max_duration=args.duration, + shuffle=True, + ) + logging.info("About to create train dataloader") + train_dl = torch.utils.data.DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=4 + ) + logging.info("About to create dev dataset") + validate = K2SpeechRecognitionDataset(cuts_dev) + valid_sampler = SingleCutSampler(cuts_dev, max_frames=90000) + logging.info("About to create dev dataloader") + valid_dl = torch.utils.data.DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=1 + ) + + if not torch.cuda.is_available(): + logging.error('No GPU detected!') + sys.exit(-1) + + logging.info("About to create model") + device_id = 0 + device = torch.device('cuda', device_id) + crdnn = CRDNN( + input_size=80, + time_pooling=True) + + linear = Linear( + input_size=512, + n_neurons=len(phone_ids) + 1) + + model = crdnn_model(80, len(phone_ids)+1, 2, crdnn, linear) + model = model.to(device) + + describe(model) + + learning_rate = 0.7e-3 + optimizer = optim.AdamW(model.parameters(), + lr=learning_rate, + weight_decay=5e-4) + + best_objf = np.inf + best_valid_objf = np.inf + best_epoch = start_epoch + best_model_path = os.path.join(exp_dir, 'best_model.pt') + best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info') + global_batch_idx_train = 0 # for logging only + + if start_epoch > 0: + model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(start_epoch - 1)) + ckpt = load_checkpoint(filename=model_path, model=model, optimizer=optimizer) + best_objf = ckpt['objf'] + best_valid_objf = ckpt['valid_objf'] + global_batch_idx_train = ckpt['global_batch_idx_train'] + logging.info(f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}") + + for epoch in range(start_epoch, args.epochs): + train_sampler.set_epoch(epoch) + curr_learning_rate = 1e-3 + # curr_learning_rate = learning_rate * pow(0.4, epoch) + # for param_group in optimizer.param_groups: + # param_group['lr'] = curr_learning_rate + + tb_writer.add_scalar('learning_rate', curr_learning_rate, epoch) + + logging.info('epoch {}, learning rate {}'.format( + epoch, curr_learning_rate)) + objf, valid_objf, global_batch_idx_train = train_one_epoch(dataloader=train_dl, + valid_dataloader=valid_dl, + model=model, + device=device, + graph_compiler=graph_compiler, + optimizer=optimizer, + current_epoch=epoch, + tb_writer=tb_writer, + num_epochs=num_epochs, + global_batch_idx_train=global_batch_idx_train) + # the lower, the better + if valid_objf < best_valid_objf: + best_valid_objf = valid_objf + best_objf = objf + best_epoch = epoch + save_checkpoint(filename=best_model_path, + model=model, + epoch=epoch, + optimizer=None, + scheduler=None, + learning_rate=curr_learning_rate, + objf=objf, + valid_objf=valid_objf, + global_batch_idx_train=global_batch_idx_train) + save_training_info(filename=best_epoch_info_filename, + model_path=best_model_path, + current_epoch=epoch, + learning_rate=curr_learning_rate, + objf=best_objf, + best_objf=best_objf, + valid_objf=valid_objf, + best_valid_objf=best_valid_objf, + best_epoch=best_epoch) + + # we always save the model for every epoch + model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch)) + save_checkpoint(filename=model_path, + model=model, + optimizer=optimizer, + scheduler=None, + epoch=epoch, + learning_rate=curr_learning_rate, + objf=objf, + valid_objf=valid_objf, + global_batch_idx_train=global_batch_idx_train) + epoch_info_filename = os.path.join(exp_dir, + 'epoch-{}-info'.format(epoch)) + save_training_info(filename=epoch_info_filename, + model_path=model_path, + current_epoch=epoch, + learning_rate=curr_learning_rate, + objf=objf, + best_objf=best_objf, + valid_objf=valid_objf, + best_valid_objf=best_valid_objf, + best_epoch=best_epoch) + + logging.warning('Done') + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == '__main__': + main() \ No newline at end of file From 54e4476bbd8d87f7da7cd752ecded167c42f904f Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Mon, 13 Sep 2021 13:28:25 +0800 Subject: [PATCH 10/10] Rename --- .../simple_v1/{ctc_decode_with_CRDNN.py => ctc_crdnn_decode.py} | 0 .../asr/simple_v1/{ctc_train_with_CRDNN.py => ctc_crdnn_train.py} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename egs/timit/asr/simple_v1/{ctc_decode_with_CRDNN.py => ctc_crdnn_decode.py} (100%) rename egs/timit/asr/simple_v1/{ctc_train_with_CRDNN.py => ctc_crdnn_train.py} (100%) diff --git a/egs/timit/asr/simple_v1/ctc_decode_with_CRDNN.py b/egs/timit/asr/simple_v1/ctc_crdnn_decode.py similarity index 100% rename from egs/timit/asr/simple_v1/ctc_decode_with_CRDNN.py rename to egs/timit/asr/simple_v1/ctc_crdnn_decode.py diff --git a/egs/timit/asr/simple_v1/ctc_train_with_CRDNN.py b/egs/timit/asr/simple_v1/ctc_crdnn_train.py similarity index 100% rename from egs/timit/asr/simple_v1/ctc_train_with_CRDNN.py rename to egs/timit/asr/simple_v1/ctc_crdnn_train.py