Skip to content

Commit

Permalink
allows dictionary files w/o the counts column; rename task's
Browse files Browse the repository at this point in the history
--max-num-expansions-per-step to
--transducer-max-num-expansions-per-step (same as generation's) and its default is 20; prints out word counts
after WER evaluation; fixes decoding log write out
  • Loading branch information
freewym committed Jul 24, 2023
1 parent e0e61e2 commit 660facf
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 11 deletions.
9 changes: 5 additions & 4 deletions espresso/speech_recognize.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def _main(cfg, output_file):
datefmt="%Y-%m-%d %H:%M:%S",
level=os.environ.get("LOGLEVEL", "INFO").upper(),
stream=output_file,
force=True,
)
logger = logging.getLogger("espresso.speech_recognize")
if output_file is not sys.stdout: # also print to stdout
Expand Down Expand Up @@ -359,8 +360,8 @@ def decode_fn(x):
with open(
os.path.join(cfg.common_eval.results_path, fn), "w", encoding="utf-8"
) as f:
res = "WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%".format(
*(scorer.wer())
res = "WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%, #words={:d}".format(
*(scorer.wer()), scorer.tot_word_count()
)
logger.info(header + res)
f.write(res + "\n")
Expand All @@ -370,8 +371,8 @@ def decode_fn(x):
with open(
os.path.join(cfg.common_eval.results_path, fn), "w", encoding="utf-8"
) as f:
res = "CER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%".format(
*(scorer.cer())
res = "CER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%, #chars={:d}".format(
*(scorer.cer()), scorer.tot_char_count()
)
logger.info(" " * len(header) + res)
f.write(res + "\n")
Expand Down
40 changes: 34 additions & 6 deletions espresso/tasks/speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import json
import logging
import os
import tempfile
from argparse import Namespace
from collections import OrderedDict
from dataclasses import dataclass, field
Expand Down Expand Up @@ -84,11 +85,11 @@ class SpeechRecognitionEspressoConfig(FairseqDataclass):
"moving EOS to the beginning of that) as input feeding"
},
)
max_num_expansions_per_step: int = field(
default=2,
transducer_max_num_expansions_per_step: int = field(
default=II("generation.transducer_max_num_expansions_per_step"),
metadata={
"help": "the maximum number of non-blank expansions in a single "
"time step of decoding; only relavant when training with transducer loss"
"time step of decoding in validation; only relavant when training with transducer loss"
},
)
specaugment_config: Optional[str] = field(
Expand Down Expand Up @@ -340,6 +341,7 @@ def setup_task(cls, cfg: SpeechRecognitionEspressoConfig, **kwargs):
"""
# load dictionaries
dict_path = os.path.join(cfg.data, "dict.txt") if cfg.dict is None else cfg.dict
dict_path = cls._maybe_add_pseudo_counts_to_dict(dict_path)
enable_blank = (
True if cfg.criterion_name in ["transducer_loss", "ctc_loss"] else False
)
Expand Down Expand Up @@ -376,13 +378,39 @@ def setup_task(cls, cfg: SpeechRecognitionEspressoConfig, **kwargs):
feat_dim = src_dataset.feat_dim

if cfg.word_dict is not None:
word_dict = cls.load_dictionary(cfg.word_dict, enable_bos=False)
word_dict_path = cfg.word_dict
word_dict_path = cls._maybe_add_pseudo_counts_to_dict(word_dict_path)
word_dict = cls.load_dictionary(word_dict_path, enable_bos=False)
logger.info("word dictionary: {} types".format(len(word_dict)))
return cls(cfg, tgt_dict, feat_dim, word_dict=word_dict)

else:
return cls(cfg, tgt_dict, feat_dim)

@classmethod
def _maybe_add_pseudo_counts_to_dict(cls, dict_path):
with open(dict_path, "r", encoding="utf-8") as f:
split_list = f.readline().rstrip().rsplit(" ", 1)
if len(split_list) == 2:
try:
int(split_list[1])
return dict_path
except ValueError:
pass
logger.info(f"No counts detected in {dict_path}. Adding pseudo counts...")
with open(dict_path, "r", encoding="utf-8") as fin, tempfile.NamedTemporaryFile(
"w", encoding="utf-8", delete=False
) as fout:
for i, line in enumerate(fin):
line = line.rstrip()
if len(line) == 0:
logger.warning(
f"Empty at line {i+1} in the dictionary {dict_path}, skipping it"
)
continue
print(line + " 1", file=fout)
return fout.name

def load_dataset(
self,
split: str,
Expand Down Expand Up @@ -457,7 +485,7 @@ def build_model(self, cfg: DictConfig, from_checkpoint=False):
self.decoder_for_validation = TransducerGreedyDecoder(
[model],
self.target_dictionary,
max_num_expansions_per_step=self.cfg.max_num_expansions_per_step,
max_num_expansions_per_step=self.cfg.transducer_max_num_expansions_per_step,
bos=(
self.target_dictionary.bos()
if self.cfg.include_eos_in_transducer_loss
Expand Down Expand Up @@ -528,7 +556,7 @@ def build_generator(
beam_size=getattr(args, "beam", 1),
normalize_scores=(not getattr(args, "unnormalized", False)),
max_num_expansions_per_step=getattr(
args, "transducer_max_num_expansions_per_step", 2
args, "transducer_max_num_expansions_per_step", 20
),
expansion_beta=getattr(args, "transducer_expansion_beta", 0),
expansion_gamma=getattr(args, "transducer_expansion_gamma", None),
Expand Down
1 change: 1 addition & 0 deletions examples/asr_librispeech/run_transformer_transducer.sh
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ if [ ${stage} -le 8 ]; then
--num-shards 1 --shard-id 0 --dict $dict --bpe sentencepiece --sentencepiece-model ${sentencepiece_model}.model \
--gen-subset $dataset --max-source-positions 9999 --max-target-positions 999 \
--path $path --beam 5 --temperature 1.3 --criterion-name transducer_loss \
--transducer-max-num-expansions-per-step 20 \
--transducer-expansion-beta 2 --transducer-expansion-gamma 2.3 --transducer-prefix-alpha 1 \
--results-path $decode_dir $opts

Expand Down
2 changes: 1 addition & 1 deletion fairseq/dataclass/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1051,7 +1051,7 @@ class GenerationConfig(FairseqDataclass):
)
# for decoding transducer models
transducer_max_num_expansions_per_step: Optional[int] = field(
default=2,
default=20,
metadata={
"help": "the maximum number of non-blank expansions in a single time step"
},
Expand Down

0 comments on commit 660facf

Please sign in to comment.