diff --git a/espresso/speech_recognize.py b/espresso/speech_recognize.py
index 852ad6da9..e73801e7b 100755
--- a/espresso/speech_recognize.py
+++ b/espresso/speech_recognize.py
@@ -334,7 +334,7 @@ def decode_fn(x):
         if save_attention_plot:
             logger.info("Saved attention plots in " + save_dir)
 
-    if has_target:
+    if has_target and hasattr(task.datasets[cfg.dataset.gen_subset], "tgt"):
         scorer.add_ordered_utt_list(task.datasets[cfg.dataset.gen_subset].tgt.utt_ids)
 
     fn = "decoded_char_results.txt"
diff --git a/espresso/tools/compute_wer.py b/espresso/tools/compute_wer.py
index 739a93a5e..14e3dd982 100755
--- a/espresso/tools/compute_wer.py
+++ b/espresso/tools/compute_wer.py
@@ -75,7 +75,11 @@ def main(args):
     wer_counter = Counter()
     with open(args.hyp_text, "r", encoding="utf-8") as f:
         for line in f:
-            utt_id, text = line.strip().split(None, 1)
+            res = line.strip().split(None, 1)
+            if len(res) == 2:
+                utt_id, text = res
+            else:
+                utt_id, text = res[0], ""
             assert utt_id in refs, utt_id
 
             ref, hyp = refs[utt_id], text
diff --git a/espresso/tools/transducer_base_decoder.py b/espresso/tools/transducer_base_decoder.py
index 3567df145..21dea5f29 100644
--- a/espresso/tools/transducer_base_decoder.py
+++ b/espresso/tools/transducer_base_decoder.py
@@ -4,7 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -33,8 +33,8 @@ def __init__(
         Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models
            dictionary (~fairseq.data.Dictionary): dictionary
-            max_len (int, optional): the maximum length of the generated output
-                (not including end-of-sentence) (default: 0, no limit)
+            max_len (int, optional): the maximum number of encoder output frames
+                that are allowed to emit tokens (default: 0, no limit)
            max_num_expansions_per_step (int, optional): the maximum number of
                non-blank expansions in a single time step (default: 2)
            temperature (float, optional): temperature, where values
@@ -95,7 +95,9 @@ def cuda(self):
         return self
 
     @torch.no_grad()
-    def decode(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs):
+    def decode(
+        self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs
+    ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
         """Generate a batch of 1-best hypotheses. Match the API of other fairseq generators.
 
         Normally called for validation during training.
@@ -139,6 +141,24 @@ def generate(
     @torch.no_grad()
     def _generate(
         self, sample: Dict[str, Dict[str, Tensor]], bos_token: Optional[int] = None
-    ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
-        # should return a tuple of tokens, scores and alignments
+    ) -> Tuple[
+        Union[Tensor, List[Tensor]],
+        Union[Tensor, List[Tensor]],
+        Optional[Union[Tensor, List[Tensor]]],
+    ]:
+        """Implement the decoding algorithm here.
+        Should return a tuple of tokens, scores and alignments.
+
+        Args:
+            feature (Tensor): feature (i.e. `sample["net_input"]["src_tokens"]`)
+                of shape `(batch, feature_length, feature_dim)`
+            feature_lens (Tensor, optional): feature lengths of shape `(batch)`
+
+        Returns:
+            tokens (LongTensor or List[LongTensor]): token sequences of shape
+                `(batch, max_dec_out_length)`
+            scores (FloatTensor or List[FloatTensor]): scores of shape `(batch)`
+            alignments (LongTensor or List[LongTensor], optional): alignments of
+                shape `(batch, max_enc_out_length, max_num_tokens_per_step)`
+        """
         raise NotImplementedError
diff --git a/espresso/tools/transducer_beam_search_decoder.py b/espresso/tools/transducer_beam_search_decoder.py
index 013daed51..2a996b651 100644
--- a/espresso/tools/transducer_beam_search_decoder.py
+++ b/espresso/tools/transducer_beam_search_decoder.py
@@ -3,7 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
 
 import torch
 from torch import Tensor
@@ -46,8 +46,8 @@ def __init__(
            models (List[~fairseq.models.FairseqModel]): ensemble of models
            dictionary (~fairseq.data.Dictionary): dictionary
            beam_size (int, optional): beam width (default: 1)
-            max_len (int, optional): the maximum length of the generated output
-                (not including end-of-sentence) (default: 0, no limit)
+            max_len (int, optional): the maximum number of encoder output frames
+                that are allowed to emit tokens (default: 0, no limit)
            max_num_expansions_per_step (int, optional): the maximum number of
                non-blank expansions in a single time step (default: 2)
            expansion_beta (int, optional): maximum number of prefix expansions allowed,
@@ -63,8 +63,8 @@ def __init__(
            prefix_alpha (int, optional): maximum prefix length in prefix search.
                Must be an integer, and is advised to keep this as 1 in order to
                reduce expensive beam search cost later (default: 1)
-            normalize_scores (bool, optional): normalize scores by the length
-                of the output including blank (default: True)
+            normalize_scores (bool, optional): if True, normalize scores by the length
+                of the output (including blank) when sorting/ranking hypotheses (default: True)
            temperature (float, optional): temperature, where values
                >1.0 produce more uniform samples and values <1.0 produce sharper
                samples (default: 1.0)
@@ -101,7 +101,9 @@ def __init__(
         self.normalize_scores = normalize_scores
 
     @torch.no_grad()
-    def decode(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs):
+    def decode(
+        self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs
+    ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
         assert not self.print_alignment
         tokens_list, scores_list, _ = self._generate(sample, **kwargs)
 
@@ -149,7 +151,7 @@ def generate(
     @torch.no_grad()
     def _generate(
         self, sample: Dict[str, Dict[str, Tensor]], bos_token: Optional[int] = None
-    ):
+    ) -> Tuple[List[Tensor], List[Tensor], Optional[List[Tensor]]]:
         net_input = sample["net_input"]
         src_tokens = net_input["src_tokens"]
         bsz, src_len = src_tokens.size()[:2]
@@ -158,14 +160,16 @@ def _generate(
         encoder_outs = self.model.encoder.forward_torchscript(net_input)
         enc_out = encoder_outs["encoder_out"][0].transpose(0, 1)  # B x T x C
         enc_out_lengths = encoder_outs["src_lengths"][0]  # B
-        sequences_list, scores_list, alignments_list = [], [], []
+        sequences_list, scores_list = [], []
+        alignments_list = [] if self.print_alignment else None
         for i in range(bsz):
             sequences, scores, alignments = self._generate_one_example(
                 enc_out[i : i + 1, :, :], enc_out_lengths[i], bos_token=bos_token
             )
             sequences_list.append(sequences)
             scores_list.append(scores)
-            alignments_list.append(alignments)
+            if self.print_alignment:
+                alignments_list.append(alignments)
 
         return sequences_list, scores_list, alignments_list
 
@@ -185,7 +189,7 @@ def _generate_one_example(
         if self.print_alignment:
             # token alignments for the final best hypothesis. +1 for blank
             alignments = enc_out.new_full(
-                (enc_out_length, self.max_num_expansions_per_step + 1),
+                (1, enc_out_length, self.max_num_expansions_per_step + 1),
                 self.pad,
                 dtype=torch.long,
             )  # T x N (max num tokens per time step)
@@ -336,102 +340,100 @@ def _generate_one_example(
                     self.beam_size, normalize_by_length=self.normalize_scores
                 )
                 break
-            else:
-                # forward the decoder with `k_expanded_hyps_nonblank`
-                self.model.decoder.set_incremental_state(
-                    incremental_state,
+            # forward the decoder with `k_expanded_hyps_nonblank`
+            self.model.decoder.set_incremental_state(
+                incremental_state,
+                "cached_state",
+                k_expanded_hyps_nonblank.cached_state,
+            )
+            dec_out = self.model.decoder.extract_features(
+                k_expanded_hyps_nonblank.prev_tokens.unsqueeze(1),
+                incremental_state=incremental_state,
+            )[
+                0
+            ]  # B x 1 x H
+            k_expanded_hyps_nonblank.cached_state = (
+                self.model.decoder.get_incremental_state(
+                    incremental_state, "cached_state"
+                )
+            )  # update `cached_state`
+
+            if self.lm_model is not None:
+                self.lm_model.decoder.set_incremental_state(
+                    lm_incremental_state,
                     "cached_state",
-                    k_expanded_hyps_nonblank.cached_state,
+                    k_expanded_hyps_nonblank.lm_cached_state,
                 )
-                dec_out = self.model.decoder.extract_features(
-                    k_expanded_hyps_nonblank.prev_tokens.unsqueeze(1),
-                    incremental_state=incremental_state,
+                prev_tokens = k_expanded_hyps_nonblank.prev_tokens
+                lm_prev_tokens = (
+                    torch.where(
+                        prev_tokens > self.blank, prev_tokens - 1, prev_tokens
+                    )
+                    if self.no_blank_in_lm
+                    else prev_tokens
+                )
+                lm_dec_out = self.lm_model.extract_features(
+                    lm_prev_tokens.unsqueeze(1),
+                    incremental_state=lm_incremental_state,
                 )[
                     0
-                ]  # B x 1 x H
-                k_expanded_hyps_nonblank.cached_state = (
-                    self.model.decoder.get_incremental_state(
-                        incremental_state, "cached_state"
+                ]  # B x 1 x H'
+                k_expanded_hyps_nonblank.lm_cached_state = (
-                        self.lm_model.decoder.get_incremental_state(
-                            lm_incremental_state, "cached_state"
-                        )
-                    )  # update `lm_cached_state`
-                else:
-                    lm_dec_out = None
-
-                k_expanded_hyps_nonblank.update_dec_out_(
-                    dec_out, lm_dec_out=lm_dec_out
-                )  # update `dec_out` and `lm_dec_out`
-
-                if (
-                    expansion_idx < self.max_num_expansions_per_step - 1
-                ):  # not the last round of expansion within this time step
-                    # prepare for the next round of expansion within this time step
-                    hyps = k_expanded_hyps_nonblank
-                else:  # the last round of expansion within this time step
-                    # add blank probability to non-blank hyps, combine and prune the hyps for the next time step
-                    logits = (
-                        self.model.joint(
-                            enc_out_this_step.expand(dec_out.size(0), -1, -1),
-                            dec_out,
-                            apply_output_layer=True,
-                        )
+                    .squeeze(2)
+                    .squeeze(1)
+                )  # B x 1 x 1 x V -> B x V
+                lprobs = self.model.get_normalized_probs(
+                    (logits.div_(self.temperature), None), log_probs=True
+                )  # B x V
 
-                    k_expanded_hyps_nonblank.scores += lprobs[:, self.blank]
-                    k_expanded_hyps_nonblank.prev_tokens.fill_(
-                        self.blank
-                    )  # unnecessary but conceptually should do
-                    k_expanded_hyps_nonblank.num_emissions += 1
+                k_expanded_hyps_nonblank.scores += lprobs[:, self.blank]
+                k_expanded_hyps_nonblank.prev_tokens.fill_(
+                    self.blank
+                )  # not strictly necessary, but kept for conceptual consistency
+                k_expanded_hyps_nonblank.num_emissions += 1
 
-                    next_step_hyps = k_expanded_hyps_blank.combine(
-                        k_expanded_hyps_nonblank, pad_idx=self.pad
-                    )
-                    next_step_hyps.keep_top_k_(
-                        self.beam_size, normalize_by_length=self.normalize_scores
-                    )
+                next_step_hyps = k_expanded_hyps_blank.combine(
+                    k_expanded_hyps_nonblank, pad_idx=self.pad
+                )
+                next_step_hyps.keep_top_k_(
+                    self.beam_size, normalize_by_length=self.normalize_scores
+                )
 
-        next_step_hyps.sort_by_score_(
-            descending=True, normalize_by_length=self.normalize_scores
-        )
+        # normalize scores by sequence length and sort final hyps
+        next_step_hyps.scores.div_(next_step_hyps.sequence_lengths - 1)  # B
+        next_step_hyps.sort_by_score_(descending=True)
 
         # get the N-best hypotheses, and exclude the leading EOS token from the sequences
         sequences = next_step_hyps.sequences[:, 1:]  # B x U
-        scores = next_step_hyps.scores / (next_step_hyps.sequence_lengths - 1)  # B
         if self.print_alignment:
             alignments = next_step_hyps.alignments  # B x T x N
         else:
             alignments = None
-        return sequences, scores, alignments
+        return sequences, next_step_hyps.scores, alignments
 
     def prefix_search_and_merge(
         self, hyps: Hypotheses, enc_out: Tensor, alpha: Optional[int] = None
@@ -481,7 +483,7 @@ def prefix_search_and_merge(
                 ).squeeze(
                     0
                 )  # 1 x V -> V
-                token_index = hyps.sequences[j][len_i]
+                token_index = hyps.sequences[j, len_i]
                 score = hyps.scores[i] + lprobs[token_index]
 
                 if self.lm_model is not None:
@@ -530,7 +532,7 @@ def prefix_search_and_merge(
                     ).squeeze(
                         0
                     )  # 1 x V -> V
-                    token_index = hyps.sequences[j][k + 1]
+                    token_index = hyps.sequences[j, k + 1]
                     score += lprobs[token_index]
 
                     if self.lm_model is not None:
diff --git a/espresso/tools/transducer_greedy_decoder.py b/espresso/tools/transducer_greedy_decoder.py
index 8db9932e8..8790cc173 100644
--- a/espresso/tools/transducer_greedy_decoder.py
+++ b/espresso/tools/transducer_greedy_decoder.py
@@ -3,7 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Dict, Optional
+from typing import Dict, Optional, Tuple
 
 import torch
 from torch import Tensor
@@ -32,8 +32,8 @@ def __init__(
         Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models
            dictionary (~fairseq.data.Dictionary): dictionary
-            max_len (int, optional): the maximum length of the generated output
-                (not including end-of-sentence) (default: 0, no limit)
+            max_len (int, optional): the maximum number of encoder output frames
+                that are allowed to emit tokens (default: 0, no limit)
            max_num_expansions_per_step (int, optional): the maximum number of
                non-blank expansions in a single time step (default: 2)
            temperature (float, optional): temperature, where values
@@ -76,7 +76,7 @@ def __init__(
     @torch.no_grad()
     def _generate(
         self, sample: Dict[str, Dict[str, Tensor]], bos_token: Optional[int] = None
-    ):
+    ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
         net_input = sample["net_input"]
         src_tokens = net_input["src_tokens"]
         bsz, src_len = src_tokens.size()[:2]
diff --git a/espresso/tools/transducer_utils.py b/espresso/tools/transducer_utils.py
index dd3bab9e7..c0ae01eaf 100644
--- a/espresso/tools/transducer_utils.py
+++ b/espresso/tools/transducer_utils.py
@@ -15,7 +15,7 @@
 
 @dataclass
 class Hypotheses:
-    """Hypotheses class for beam search algorithms. data from multiple hypotheses are
+    """Hypotheses class for beam search algorithms. Data from multiple hypotheses are
     stacked along the batch dimension of each attribute tensor.
 
     scores (Tensor): scores of hypotheses (including weighted LM scores if LM is present).
@@ -385,7 +385,7 @@ def update_dec_out_(
 
     def get_last_dec_out(self) -> Tuple[Optional[Tensor], Optional[Tensor]]:
         """Returns the last `dec_out`/`lm_dec_out`, which is the output after feeding the
-        last non-blank tokens in `self.sequences` into the decoder/LM.
+        last non-blank tokens from `self.sequences` into the decoder/LM.
         Note: this function will NOT modify this instance.
 
         Returns:
@@ -654,9 +654,10 @@
     """Returns K hypotheses candidates for expansions from a set of hypotheses.
     K candidates are selected according to the extended hypotheses' probabilities
     and a prune-by-value method, where K is equal to beam_size + beta.
-    Note: This function should be followed with :func:`~Hypotheses.update_dec_out_()` in the calling code
-    to also update `k_expanded_hyps_nonblank.cached_state`, `k_expanded_hyps.dec_out`,
-    `k_expanded_hyps_nonblank.lm_cached_state`, and `k_expanded_hyps.lm_dec_out` after non-blank expansions.
+    Note: This function should be followed by updating `k_expanded_hyps_nonblank.cached_state`,
+    `k_expanded_hyps_nonblank.dec_out` (with :func:`~Hypotheses.update_dec_out_()`), `k_expanded_hyps_nonblank.lm_cached_state`,
+    and `k_expanded_hyps_nonblank.lm_dec_out` (also with :func:`~Hypotheses.update_dec_out_()`) in the calling code
+    after non-blank expansions.
 
     This implementation is modified from
     https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transducer/utils.py
@@ -684,9 +685,9 @@
     """
     assert hyps.size() > 0
     lprobs = lprobs + hyps.scores.unsqueeze(-1)  # B x V
-    K = min(beam_size + beta, lprobs.size(1) - 1)  # -1 so we never select pad
-    scores, indices = torch.topk(lprobs, k=K)  # B x K
-    k_expanded_hyps = hyps.repeat_interleave(K)  # (B * K) hypotheses
+    k = min(beam_size + beta, lprobs.size(1) - 1)  # -1 so we never select pad
+    scores, indices = torch.topk(lprobs, k=k)  # B x K
+    k_expanded_hyps = hyps.repeat_interleave(k)  # (B * K) hypotheses
     k_expanded_hyps.scores = scores.view(-1)  # (B * K)
     if lm_lprobs_padded is not None:
         assert lm_lprobs_padded.size() == lprobs.size()
@@ -704,7 +705,7 @@
     if not retained_mask.all():
         # prune by value
         k_expanded_hyps = k_expanded_hyps.masked_select(retained_mask.view(-1))
-    k_expanded_hyps.keep_top_k_(K, normalize_by_length=normalize_by_length)
+    k_expanded_hyps.keep_top_k_(k, normalize_by_length=normalize_by_length)
 
     return k_expanded_hyps
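
Note on the compute_wer.py hunk above: `str.split(None, 1)` returns a single-element list when a hypothesis line carries only an utterance ID (i.e. the decoded text is empty), so the old two-target unpacking raised a ValueError. A minimal standalone sketch of the fixed parsing logic (the utterance IDs and text below are made up, not taken from any real data):

    lines = [
        "utt1 HELLO WORLD",  # normal case: utterance ID followed by text
        "utt2",              # empty hypothesis: utterance ID only
    ]
    for line in lines:
        res = line.strip().split(None, 1)
        if len(res) == 2:
            utt_id, text = res
        else:
            utt_id, text = res[0], ""  # treat a missing hypothesis as empty text
        print(repr(utt_id), repr(text))
    # prints: 'utt1' 'HELLO WORLD', then 'utt2' ''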
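Note on the score-normalization hunk in transducer_beam_search_decoder.py: dividing `next_step_hyps.scores` by the emitted length before the final sort means hypotheses are now ranked by their length-normalized scores, and the returned scores are the same values used for ranking. A toy illustration with made-up numbers (assuming scores are cumulative log-probs and lengths exclude the leading EOS):

    import torch

    raw_scores = torch.tensor([-4.0, -5.5])  # cumulative log-probs of two hypotheses
    lengths = torch.tensor([2.0, 5.0])       # number of emitted tokens per hypothesis

    print(raw_scores.argmax().item())              # 0: hyp 0 wins unnormalized
    print((raw_scores / lengths).argmax().item())  # 1: hyp 1 wins per-token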
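Note on the select_k_expansions() hunks in transducer_utils.py: the renamed `k` is the number of expansion candidates kept per hypothesis (beam_size + beta, capped so pad is never selected), after which candidates are pruned by value. A simplified standalone sketch on plain tensors; the real function operates on a Hypotheses instance, handles LM scores, and re-applies keep_top_k_ after pruning, and `gamma` here is an assumed prune-by-value margin:

    import torch

    def select_k_expansions_sketch(lprobs, hyp_scores, beam_size, beta, gamma):
        # lprobs: B x V token log-probs; hyp_scores: B cumulative hypothesis scores
        lprobs = lprobs + hyp_scores.unsqueeze(-1)  # B x V extended scores
        k = min(beam_size + beta, lprobs.size(1) - 1)  # -1 so we never select pad
        scores, indices = torch.topk(lprobs, k=k)  # B x K, sorted descending
        best = scores[:, :1]  # per-hypothesis best extended score
        retained_mask = scores >= best - gamma  # prune by value
        return scores[retained_mask], indices[retained_mask]

    scores, indices = select_k_expansions_sketch(
        torch.log_softmax(torch.randn(2, 10), dim=-1),
        torch.zeros(2),
        beam_size=4,
        beta=2,
        gamma=2.3,
    )
    print(scores.shape, indices.shape)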