From 660facf088ded9f084cc1a24a1f00f64ce5f6918 Mon Sep 17 00:00:00 2001 From: freewym Date: Thu, 20 Jul 2023 19:05:26 -0400 Subject: [PATCH] allows dictionary files w/o the counts column; rename task's --max-num-expansions-per-step to --transducer-max-num-expansions-per-step (same as generation's) and its default is 20; prints out word counts after WER evaluation; fixes decoding log write out --- espresso/speech_recognize.py | 9 +++-- espresso/tasks/speech_recognition.py | 40 ++++++++++++++++--- .../run_transformer_transducer.sh | 1 + fairseq/dataclass/configs.py | 2 +- 4 files changed, 41 insertions(+), 11 deletions(-) diff --git a/espresso/speech_recognize.py b/espresso/speech_recognize.py index e73801e7b..e0bc56164 100755 --- a/espresso/speech_recognize.py +++ b/espresso/speech_recognize.py @@ -64,6 +64,7 @@ def _main(cfg, output_file): datefmt="%Y-%m-%d %H:%M:%S", level=os.environ.get("LOGLEVEL", "INFO").upper(), stream=output_file, + force=True, ) logger = logging.getLogger("espresso.speech_recognize") if output_file is not sys.stdout: # also print to stdout @@ -359,8 +360,8 @@ def decode_fn(x): with open( os.path.join(cfg.common_eval.results_path, fn), "w", encoding="utf-8" ) as f: - res = "WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%".format( - *(scorer.wer()) + res = "WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%, #words={:d}".format( + *(scorer.wer()), scorer.tot_word_count() ) logger.info(header + res) f.write(res + "\n") @@ -370,8 +371,8 @@ def decode_fn(x): with open( os.path.join(cfg.common_eval.results_path, fn), "w", encoding="utf-8" ) as f: - res = "CER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%".format( - *(scorer.cer()) + res = "CER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%, #chars={:d}".format( + *(scorer.cer()), scorer.tot_char_count() ) logger.info(" " * len(header) + res) f.write(res + "\n") diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index 484d6813c..873f4c829 100644 --- 
a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -7,6 +7,7 @@ import json import logging import os +import tempfile from argparse import Namespace from collections import OrderedDict from dataclasses import dataclass, field @@ -84,11 +85,11 @@ class SpeechRecognitionEspressoConfig(FairseqDataclass): "moving EOS to the beginning of that) as input feeding" }, ) - max_num_expansions_per_step: int = field( - default=2, + transducer_max_num_expansions_per_step: int = field( + default=II("generation.transducer_max_num_expansions_per_step"), metadata={ "help": "the maximum number of non-blank expansions in a single " - "time step of decoding; only relavant when training with transducer loss" + "time step of decoding in validation; only relevant when training with transducer loss" }, ) specaugment_config: Optional[str] = field( @@ -340,6 +341,7 @@ def setup_task(cls, cfg: SpeechRecognitionEspressoConfig, **kwargs): """ # load dictionaries dict_path = os.path.join(cfg.data, "dict.txt") if cfg.dict is None else cfg.dict + dict_path = cls._maybe_add_pseudo_counts_to_dict(dict_path) enable_blank = ( True if cfg.criterion_name in ["transducer_loss", "ctc_loss"] else False ) @@ -376,13 +378,39 @@ def setup_task(cls, cfg: SpeechRecognitionEspressoConfig, **kwargs): feat_dim = src_dataset.feat_dim if cfg.word_dict is not None: - word_dict = cls.load_dictionary(cfg.word_dict, enable_bos=False) + word_dict_path = cfg.word_dict + word_dict_path = cls._maybe_add_pseudo_counts_to_dict(word_dict_path) + word_dict = cls.load_dictionary(word_dict_path, enable_bos=False) logger.info("word dictionary: {} types".format(len(word_dict))) return cls(cfg, tgt_dict, feat_dim, word_dict=word_dict) else: return cls(cfg, tgt_dict, feat_dim) + @classmethod + def _maybe_add_pseudo_counts_to_dict(cls, dict_path): + with open(dict_path, "r", encoding="utf-8") as f: + split_list = f.readline().rstrip().rsplit(" ", 1) + if len(split_list) == 2: + try: + 
int(split_list[1]) + return dict_path + except ValueError: + pass + logger.info(f"No counts detected in {dict_path}. Adding pseudo counts...") + with open(dict_path, "r", encoding="utf-8") as fin, tempfile.NamedTemporaryFile( + "w", encoding="utf-8", delete=False + ) as fout: + for i, line in enumerate(fin): + line = line.rstrip() + if len(line) == 0: + logger.warning( + f"Empty at line {i+1} in the dictionary {dict_path}, skipping it" + ) + continue + print(line + " 1", file=fout) + return fout.name + def load_dataset( self, split: str, @@ -457,7 +485,7 @@ def build_model(self, cfg: DictConfig, from_checkpoint=False): self.decoder_for_validation = TransducerGreedyDecoder( [model], self.target_dictionary, - max_num_expansions_per_step=self.cfg.max_num_expansions_per_step, + max_num_expansions_per_step=self.cfg.transducer_max_num_expansions_per_step, bos=( self.target_dictionary.bos() if self.cfg.include_eos_in_transducer_loss @@ -528,7 +556,7 @@ def build_generator( beam_size=getattr(args, "beam", 1), normalize_scores=(not getattr(args, "unnormalized", False)), max_num_expansions_per_step=getattr( - args, "transducer_max_num_expansions_per_step", 2 + args, "transducer_max_num_expansions_per_step", 20 ), expansion_beta=getattr(args, "transducer_expansion_beta", 0), expansion_gamma=getattr(args, "transducer_expansion_gamma", None), diff --git a/examples/asr_librispeech/run_transformer_transducer.sh b/examples/asr_librispeech/run_transformer_transducer.sh index 29b82b819..4157c9590 100755 --- a/examples/asr_librispeech/run_transformer_transducer.sh +++ b/examples/asr_librispeech/run_transformer_transducer.sh @@ -261,6 +261,7 @@ if [ ${stage} -le 8 ]; then --num-shards 1 --shard-id 0 --dict $dict --bpe sentencepiece --sentencepiece-model ${sentencepiece_model}.model \ --gen-subset $dataset --max-source-positions 9999 --max-target-positions 999 \ --path $path --beam 5 --temperature 1.3 --criterion-name transducer_loss \ + --transducer-max-num-expansions-per-step 20 \ 
--transducer-expansion-beta 2 --transducer-expansion-gamma 2.3 --transducer-prefix-alpha 1 \ --results-path $decode_dir $opts diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 7f9074c78..858f19c29 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -1051,7 +1051,7 @@ class GenerationConfig(FairseqDataclass): ) # for decoding transducer models transducer_max_num_expansions_per_step: Optional[int] = field( - default=2, + default=20, metadata={ "help": "the maximum number of non-blank expansions in a single time step" },