diff --git a/espresso/criterions/ctc_loss.py b/espresso/criterions/ctc_loss.py index ef8453106..f48e6ade9 100644 --- a/espresso/criterions/ctc_loss.py +++ b/espresso/criterions/ctc_loss.py @@ -12,10 +12,11 @@ import torch.nn.functional as F from omegaconf import II -from fairseq import metrics, utils +from fairseq import utils from fairseq.criterions import FairseqCriterion, register_criterion from fairseq.data import data_utils from fairseq.dataclass import FairseqDataclass +from fairseq.logging import metrics from fairseq.tasks import FairseqTask logger = logging.getLogger(__name__) diff --git a/espresso/criterions/transducer_loss.py b/espresso/criterions/transducer_loss.py index 5ea368911..553c699f8 100644 --- a/espresso/criterions/transducer_loss.py +++ b/espresso/criterions/transducer_loss.py @@ -11,10 +11,11 @@ import torch from omegaconf import II -from fairseq import metrics, utils +from fairseq import utils from fairseq.criterions import FairseqCriterion, register_criterion from fairseq.data import data_utils from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.logging import metrics from fairseq.tasks import FairseqTask logger = logging.getLogger(__name__) @@ -36,6 +37,7 @@ class TransducerLossCriterionConfig(FairseqDataclass): default="torchaudio", metadata={"help": "choice of loss backend (native or torchaudio)"}, ) + include_eos: bool = II("task.include_eos_in_transducer_loss") @register_criterion("transducer_loss", dataclass=TransducerLossCriterionConfig) @@ -64,6 +66,7 @@ def __init__(self, cfg: TransducerLossCriterionConfig, task: FairseqTask): ) self.rnnt_loss = rnnt_loss + self.include_eos = cfg.include_eos self.dictionary = task.target_dictionary self.prev_num_updates = -1 @@ -73,13 +76,15 @@ def forward(self, model, sample, reduce=True): ) # B x T x U x V, B if "target_lengths" in sample: - target_lengths = ( - sample["target_lengths"].int() - 1 - ) # Note: ensure EOS is excluded + target_lengths = sample["target_lengths"].int() + if not self.include_eos: + target_lengths -= 1 # excludes EOS else: target_lengths = ( ( (sample["target"] != self.pad_idx) + if self.include_eos + else (sample["target"] != self.pad_idx) & (sample["target"] != self.eos_idx) ) .sum(-1) @@ -124,7 +129,9 @@ def forward(self, model, sample, reduce=True): loss = self.rnnt_loss( net_output, - sample["target"][:, :-1].int().contiguous(), # exclude the last EOS column + (sample["target"] if self.include_eos else sample["target"][:, :-1]) + .int() + .contiguous(), encoder_out_lengths.int(), target_lengths, blank=self.blank_idx, diff --git a/espresso/data/asr_dataset.py b/espresso/data/asr_dataset.py index a6e186075..8ee7b6410 100644 --- a/espresso/data/asr_dataset.py +++ b/espresso/data/asr_dataset.py @@ -21,6 +21,7 @@ def collate( left_pad_source=True, left_pad_target=False, input_feeding=True, + maybe_bos_idx=None, pad_to_length=None, pad_to_multiple=1, src_bucketed=False, @@ -89,11 +90,16 @@ def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None): prev_output_tokens = merge( "target", left_pad=left_pad_target, - move_eos_to_beginning=True, + move_eos_to_beginning=(maybe_bos_idx is None), pad_to_length=pad_to_length["target"] if pad_to_length is not None else None, ) + if maybe_bos_idx is not None: + all_bos_vec = prev_output_tokens.new_full((1, 1), maybe_bos_idx).expand( + len(samples), 1 + ) + prev_output_tokens = torch.cat([all_bos_vec, prev_output_tokens], dim=1) else: ntokens = src_lengths.sum().item() @@ -148,6 +154,10 @@ class AsrDataset(FairseqDataset): 
(default: True). input_feeding (bool, optional): create a shifted version of the targets to be passed into the model for teacher forcing (default: True). + prepend_bos_as_input_feeding (bool, optional): target prepended with BOS symbol + (instead of moving EOS to the beginning of that) as input feeding. This is + currently only for a transducer model training setting where EOS is retained + in target when evaluating the loss (default: False). constraints (Tensor, optional): 2d tensor with a concatenated, zero- delimited list of constraints for each sentence. num_buckets (int, optional): if set to a value greater than 0, then @@ -176,6 +186,7 @@ def __init__( left_pad_target=False, shuffle=True, input_feeding=True, + prepend_bos_as_input_feeding=False, constraints=None, num_buckets=0, src_lang_id=None, @@ -193,6 +204,7 @@ def __init__( self.left_pad_target = left_pad_target self.shuffle = shuffle self.input_feeding = input_feeding + self.prepend_bos_as_input_feeding = prepend_bos_as_input_feeding self.constraints = constraints self.src_lang_id = src_lang_id self.tgt_lang_id = tgt_lang_id @@ -334,6 +346,9 @@ def collater(self, samples, pad_to_length=None): left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target, input_feeding=self.input_feeding, + maybe_bos_idx=self.dictionary.bos() + if self.prepend_bos_as_input_feeding + else None, pad_to_length=pad_to_length, pad_to_multiple=self.pad_to_multiple, src_bucketed=(self.buckets is not None), diff --git a/espresso/data/feat_text_dataset.py b/espresso/data/feat_text_dataset.py index eb38457a9..d7d9197db 100644 --- a/espresso/data/feat_text_dataset.py +++ b/espresso/data/feat_text_dataset.py @@ -57,7 +57,7 @@ def __init__( ): super().__init__() assert len(utt_ids) == len(rxfiles) - self.dtype = np.float + self.dtype = float self.utt_ids = utt_ids self.rxfiles = rxfiles self.size = len(utt_ids) # number of utterances @@ -338,7 +338,6 @@ def __init__( self, utt_ids: List[str], texts: List[str], dictionary=None, append_eos=True ): super().__init__() - self.dtype = np.float self.dictionary = dictionary self.append_eos = append_eos self.read_text(utt_ids, texts, dictionary) diff --git a/espresso/models/external_language_model.py b/espresso/models/external_language_model.py index ffb168356..4c9ab800f 100644 --- a/espresso/models/external_language_model.py +++ b/espresso/models/external_language_model.py @@ -9,8 +9,9 @@ from espresso.data import AsrDictionary from espresso.tools.lexical_prefix_tree import lexical_prefix_tree -from espresso.tools.utils import clone_cached_state, tokenize +from espresso.tools.utils import tokenize from fairseq.models import FairseqIncrementalDecoder, FairseqLanguageModel +from fairseq.utils import apply_to_sample class RawOutExternalLanguageModelBase(FairseqLanguageModel): @@ -125,8 +126,9 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): ).unsqueeze( -1 ) # B x 1 - old_cached_state = clone_cached_state( - self.lm_decoder.get_cached_state(incremental_state) + old_cached_state = apply_to_sample( + torch.clone, + self.lm_decoder.get_cached_state(incremental_state), ) # recompute cumsum_probs from inter-word transition probabilities # only for those whose prev_output_token is @@ -432,8 +434,9 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): ).unsqueeze( -1 ) # B x 1 - old_wordlm_cached_state = clone_cached_state( - self.wordlm_decoder.get_cached_state(incremental_state) + old_wordlm_cached_state = apply_to_sample( + torch.clone, + 
self.wordlm_decoder.get_cached_state(incremental_state), ) # recompute wordlm_logprobs from inter-word transition probabilities diff --git a/espresso/models/speech_lstm.py b/espresso/models/speech_lstm.py index 6e8e322c9..c0e075fa2 100644 --- a/espresso/models/speech_lstm.py +++ b/espresso/models/speech_lstm.py @@ -1017,28 +1017,16 @@ def masked_copy_cached_state( src_cached_state[2], ) - def masked_copy_state(state: Optional[Tensor], src_state: Optional[Tensor]): - if state is None: - assert src_state is None - return None - else: - assert ( - state.size(0) == mask.size(0) - and src_state is not None - and state.size() == src_state.size() - ) - state[mask, ...] = src_state[mask, ...] - return state - - prev_hiddens = [ - masked_copy_state(p, src_p) - for (p, src_p) in zip(prev_hiddens, src_prev_hiddens) - ] - prev_cells = [ - masked_copy_state(p, src_p) - for (p, src_p) in zip(prev_cells, src_prev_cells) - ] - input_feed = masked_copy_state(input_feed, src_input_feed) + mask = mask.unsqueeze(1) + prev_hiddens = speech_utils.apply_to_sample_pair( + lambda x, y, z=mask: torch.where(z, x, y), src_prev_hiddens, prev_hiddens + ) + prev_cells = speech_utils.apply_to_sample_pair( + lambda x, y, z=mask: torch.where(z, x, y), src_prev_cells, prev_cells + ) + input_feed = speech_utils.apply_to_sample_pair( + lambda x, y, z=mask: torch.where(z, x, y), src_input_feed, input_feed + ) cached_state_new = torch.jit.annotate( Dict[str, Optional[Tensor]], diff --git a/espresso/models/tensorized_lookahead_language_model.py b/espresso/models/tensorized_lookahead_language_model.py index 818e1a970..bdb2592e3 100644 --- a/espresso/models/tensorized_lookahead_language_model.py +++ b/espresso/models/tensorized_lookahead_language_model.py @@ -10,8 +10,9 @@ from espresso.data import AsrDictionary from espresso.models.external_language_model import RawOutExternalLanguageModelBase from espresso.tools.tensorized_prefix_tree import TensorizedPrefixTree -from espresso.tools.utils import clone_cached_state, tokenize +from espresso.tools.utils import tokenize from fairseq.models import FairseqIncrementalDecoder, FairseqLanguageModel +from fairseq.utils import apply_to_sample class TensorizedLookaheadLanguageModel(RawOutExternalLanguageModelBase): @@ -131,8 +132,9 @@ def forward( w: torch.Tensor = self.tree.word_idx[nodes].unsqueeze(1) # Z[Batch, Len=1] w[w < 0] = self.word_unk_idx - old_cached_state = clone_cached_state( - self.lm_decoder.get_cached_state(incremental_state) + old_cached_state = apply_to_sample( + torch.clone, + self.lm_decoder.get_cached_state(incremental_state), ) # recompute cumsum_probs from inter-word transition probabilities # only for those whose prev_output_token is diff --git a/espresso/models/transformer/speech_transformer_decoder.py b/espresso/models/transformer/speech_transformer_decoder.py index 5018d362c..50b5aaef0 100644 --- a/espresso/models/transformer/speech_transformer_decoder.py +++ b/espresso/models/transformer/speech_transformer_decoder.py @@ -12,6 +12,7 @@ import torch.nn.functional as F from torch import Tensor +import espresso.tools.utils as speech_utils from espresso.models.transformer import SpeechTransformerConfig from espresso.modules import ( RelativePositionalEmbedding, @@ -430,34 +431,23 @@ def masked_copy_cached_state( F.pad(src_p, (0, 1)) for src_p in src_prev_key_padding_mask ] - def masked_copy_state(state: Optional[Tensor], src_state: Optional[Tensor]): - if state is None: - assert src_state is None - return None - else: - assert ( - state.size(0) == mask.size(0) - and 
src_state is not None - and state.size() == src_state.size() - ) - state[mask, ...] = src_state[mask, ...] - return state - - prev_key = [ - masked_copy_state(p, src_p) for (p, src_p) in zip(prev_key, src_prev_key) - ] - prev_value = [ - masked_copy_state(p, src_p) - for (p, src_p) in zip(prev_value, src_prev_value) - ] + kv_mask = mask.unsqueeze(1).unsqueeze(2).unsqueeze(3) + prev_key = speech_utils.apply_to_sample_pair( + lambda x, y, z=kv_mask: torch.where(z, x, y), src_prev_key, prev_key + ) + prev_value = speech_utils.apply_to_sample_pair( + lambda x, y, z=kv_mask: torch.where(z, x, y), src_prev_value, prev_value + ) if prev_key_padding_mask is None: prev_key_padding_mask = src_prev_key_padding_mask else: assert src_prev_key_padding_mask is not None - prev_key_padding_mask = [ - masked_copy_state(p, src_p) - for (p, src_p) in zip(prev_key_padding_mask, src_prev_key_padding_mask) - ] + pad_mask = mask.unsqueeze(1) + prev_key_padding_mask = speech_utils.apply_to_sample_pair( + lambda x, y, z=pad_mask: torch.where(z, x, y), + src_prev_key_padding_mask, + prev_key_padding_mask, + ) cached_state = torch.jit.annotate( Dict[str, Optional[Tensor]], diff --git a/espresso/models/transformer/speech_transformer_encoder.py b/espresso/models/transformer/speech_transformer_encoder.py index e9cea6cbb..ba9c75bf4 100644 --- a/espresso/models/transformer/speech_transformer_encoder.py +++ b/espresso/models/transformer/speech_transformer_encoder.py @@ -9,6 +9,7 @@ import torch import torch.nn as nn +from torch import Tensor import espresso.tools.utils as speech_utils from espresso.models.transformer import SpeechTransformerConfig @@ -314,7 +315,12 @@ def forward_scriptable( src_tokens, ~speech_utils.sequence_mask(src_lengths, src_tokens.size(1)), ) - has_pads = src_tokens.device.type == "xla" or encoder_padding_mask.any() + has_pads = ( + torch.tensor(src_tokens.device.type == "xla") or encoder_padding_mask.any() + ) + # Torchscript doesn't handle bool Tensor correctly, so we need to work around. + if torch.jit.is_scripting(): + has_pads = torch.tensor(1) if has_pads else torch.tensor(0) if self.fc0 is not None: x = self.dropout_module(x) @@ -330,8 +336,9 @@ def forward_scriptable( x = self.quant_noise(x) # account for padding while computing the representation - if has_pads: - x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) + x = x * ( + 1 - encoder_padding_mask.unsqueeze(-1).type_as(x) * has_pads.type_as(x) + ) # B x T x C -> T x B x C x = x.transpose(0, 1) diff --git a/espresso/tasks/speech_recognition.py b/espresso/tasks/speech_recognition.py index 4fbed725e..484d6813c 100644 --- a/espresso/tasks/speech_recognition.py +++ b/espresso/tasks/speech_recognition.py @@ -76,6 +76,14 @@ class SpeechRecognitionEspressoConfig(FairseqDataclass): "adds 'prev_output_tokens' to input and appends eos to target" }, ) + include_eos_in_transducer_loss: bool = field( + default=False, + metadata={ + "help": "If True, we retain EOS in target when evaluating the transducer loss." 
+ "It will also use target prepended with BOS (=BLANK) symbol (instead of " + "moving EOS to the beginning of that) as input feeding" + }, + ) max_num_expansions_per_step: int = field( default=2, metadata={ @@ -125,6 +133,7 @@ def get_asr_dataset_from_json( shuffle=True, pad_to_multiple=1, autoregressive=True, + prepend_bos_as_input_feeding=False, is_training_set=False, batch_based_on_both_src_tgt=False, seed=1, @@ -253,6 +262,7 @@ def get_asr_dataset_from_json( num_buckets=num_buckets, shuffle=shuffle, input_feeding=autoregressive, + prepend_bos_as_input_feeding=prepend_bos_as_input_feeding, pad_to_multiple=pad_to_multiple, batch_based_on_both_src_tgt=batch_based_on_both_src_tgt, ) @@ -311,8 +321,10 @@ def __init__( self.feat_in_channels = cfg.feat_in_channels self.extra_symbols_to_ignore = {tgt_dict.pad()} # for validation with WER if cfg.criterion_name in ["transducer_loss", "ctc_loss"]: - self.blank_symbol = tgt_dict.bos_word # reserve the bos symbol for blank - self.extra_symbols_to_ignore.add(tgt_dict.index(self.blank_symbol)) + self.blank_symbol = tgt_dict[ + tgt_dict.bos() + ] # reserve the bos symbol for blank + self.extra_symbols_to_ignore.add(tgt_dict.bos()) torch.backends.cudnn.deterministic = True # Compansate for the removel of :func:`torch.rand()` from # :func:`fairseq.distributed_utils.distributed_init()` by fairseq, @@ -406,6 +418,10 @@ def load_dataset( shuffle=(split != self.cfg.gen_subset), pad_to_multiple=self.cfg.required_seq_len_multiple, autoregressive=self.cfg.autoregressive, + prepend_bos_as_input_feeding=( + self.cfg.criterion_name == "transducer_loss" + and self.cfg.include_eos_in_transducer_loss + ), is_training_set=(split == self.cfg.train_subset), batch_based_on_both_src_tgt=(self.cfg.criterion_name == "transducer_loss"), seed=self.cfg.seed, @@ -442,6 +458,13 @@ def build_model(self, cfg: DictConfig, from_checkpoint=False): [model], self.target_dictionary, max_num_expansions_per_step=self.cfg.max_num_expansions_per_step, + bos=( + self.target_dictionary.bos() + if self.cfg.include_eos_in_transducer_loss + else self.target_dictionary.eos() + ), + blank=self.target_dictionary.index(self.blank_symbol), + model_predicts_eos=self.cfg.include_eos_in_transducer_loss, ) elif self.cfg.criterion_name == "ctc_loss": # a ctc model from espresso.tools.ctc_decoder import CTCDecoder @@ -450,7 +473,10 @@ def build_model(self, cfg: DictConfig, from_checkpoint=False): [model], self.target_dictionary, ) - else: # assume it is an attention-based encoder-decoder model + elif ( + self.cfg.criterion_name is None + or "cross_entropy" in self.cfg.criterion_name + ): # assume it is an attention-based encoder-decoder model from espresso.tools.simple_greedy_decoder import SimpleGreedyDecoder self.decoder_for_validation = SimpleGreedyDecoder( @@ -458,6 +484,8 @@ def build_model(self, cfg: DictConfig, from_checkpoint=False): self.target_dictionary, for_validation=True, ) + else: + self.decoder_for_validation = None return model @@ -505,6 +533,13 @@ def build_generator( expansion_beta=getattr(args, "transducer_expansion_beta", 0), expansion_gamma=getattr(args, "transducer_expansion_gamma", None), prefix_alpha=getattr(args, "transducer_prefix_alpha", None), + bos=( + self.target_dictionary.bos() + if self.cfg.include_eos_in_transducer_loss + else self.target_dictionary.eos() + ), + blank=self.target_dictionary.index(self.blank_symbol), + model_predicts_eos=self.cfg.include_eos_in_transducer_loss, **extra_gen_cls_kwargs, ) elif self.cfg.criterion_name == "ctc_loss": @@ -534,12 +569,13 @@ def 
build_generator( def valid_step(self, sample, model, criterion): loss, sample_size, logging_output = super().valid_step(sample, model, criterion) - ( - logging_output["word_error"], - logging_output["word_count"], - logging_output["char_error"], - logging_output["char_count"], - ) = self._inference_with_wer(self.decoder_for_validation, sample, model) + if self.decoder_for_validation is not None: + ( + logging_output["word_error"], + logging_output["word_count"], + logging_output["char_error"], + logging_output["char_count"], + ) = self._inference_with_wer(self.decoder_for_validation, sample, model) return loss, sample_size, logging_output def begin_epoch(self, epoch, model): diff --git a/espresso/tools/transducer_base_decoder.py b/espresso/tools/transducer_base_decoder.py index 21dea5f29..1452e0966 100644 --- a/espresso/tools/transducer_base_decoder.py +++ b/espresso/tools/transducer_base_decoder.py @@ -22,6 +22,9 @@ def __init__( max_num_expansions_per_step=2, temperature=1.0, eos=None, + bos=None, + blank=None, + model_predicts_eos=False, symbols_to_strip_from_output=None, lm_model=None, lm_weight=1.0, @@ -40,6 +43,15 @@ def __init__( temperature (float, optional): temperature, where values >1.0 produce more uniform samples and values <1.0 produce sharper samples (default: 1.0) + eos (int, optional): index of eos. Will be dictionary.eos() if None + (default: None) + bos (int, optional): index of bos. Will be dictionary.eos() if None + (default: None) + blank (int, optional): index of blank. Will be dictionary.bos() if + None (default: None) + model_predicts_eos(bool, optional): enable it if the transducer model was + trained to predict EOS. Probability mass of emitting EOS will be transferred + to BLANK to alleviate early stop issue during decoding (default: False) lm_model (fairseq.models.FairseqLanguageModel, optional): LM model for LM fusion (default: None) lm_weight (float, optional): LM weight for LM fusion (default: 1.0) print_alignment (bool, optional): if True returns alignments (default: False) @@ -47,11 +59,15 @@ def __init__( super().__init__() self.model = models[0] # currently only support single models self.eos = dictionary.eos() if eos is None else eos - self.blank = dictionary.bos() # we make the optional BOS symbol as blank + self.bos = dictionary.eos() if bos is None else bos + self.blank = ( + dictionary.bos() if blank is None else blank + ) # we make the optional BOS symbol as blank + self.model_predicts_eos = model_predicts_eos self.symbols_to_strip_from_output = ( - symbols_to_strip_from_output.union({self.eos, self.blank}) + symbols_to_strip_from_output.union({self.eos, self.bos, self.blank}) if symbols_to_strip_from_output is not None - else {self.eos, self.blank} + else {self.eos, self.bos, self.blank} ) self.vocab_size = len(dictionary) self.beam_size = 1 # child classes can overwrite it diff --git a/espresso/tools/transducer_beam_search_decoder.py b/espresso/tools/transducer_beam_search_decoder.py index 2a996b651..b2bc28a1f 100644 --- a/espresso/tools/transducer_beam_search_decoder.py +++ b/espresso/tools/transducer_beam_search_decoder.py @@ -32,6 +32,10 @@ def __init__( normalize_scores=True, temperature=1.0, eos=None, + bos=None, + blank=None, + pad=None, + model_predicts_eos=False, symbols_to_strip_from_output=None, lm_model=None, lm_weight=1.0, @@ -68,6 +72,17 @@ def __init__( temperature (float, optional): temperature, where values >1.0 produce more uniform samples and values <1.0 produce sharper samples (default: 1.0) + eos (int, optional): index of eos. 
Will be dictionary.eos() if None + (default: None) + bos (int, optional): index of bos. Will be dictionary.eos() if None + (default: None) + blank (int, optional): index of blank. Will be dictionary.bos() if + None (default: None) + pad (int, optional): index of pad. Will be dictionary.pad() if None + (default: None) + model_predicts_eos(bool, optional): enable it if the transducer model was + trained to predict EOS. Probability mass of emitting EOS will be transferred + to BLANK to alleviate early stop issue during decoding (default: False) lm_model (fairseq.models.FairseqLanguageModel, optional): LM model for LM fusion (default: None) lm_weight (float, optional): LM weight for LM fusion (default: 1.0) print_alignment (bool, optional): if True returns alignments (default: False) @@ -79,15 +94,20 @@ def __init__( max_num_expansions_per_step=max_num_expansions_per_step, temperature=temperature, eos=eos, + bos=bos, + blank=blank, + model_predicts_eos=model_predicts_eos, symbols_to_strip_from_output=symbols_to_strip_from_output, lm_model=lm_model, lm_weight=lm_weight, print_alignment=print_alignment, **kwargs, ) - self.pad = dictionary.pad() + self.pad = dictionary.pad() if pad is None else pad # the max beam size is the dictionary size - 1, since we never select pad - self.beam_size = min(beam_size, self.vocab_size - 1) + self.beam_size = min( + beam_size, self.vocab_size - (1 if self.pad != self.blank else 0) + ) self.expansion_beta = expansion_beta assert expansion_beta >= 0, "--expansion-beta must be non-negative" self.expansion_gamma = expansion_gamma @@ -183,7 +203,7 @@ def _generate_one_example( # prev_tokens stores the previous tokens to be fed into the decoder prev_tokens = enc_out.new_full( - (1,), self.eos if bos_token is None else bos_token, dtype=torch.long + (1,), self.bos if bos_token is None else bos_token, dtype=torch.long ) # B(=1) if self.print_alignment: @@ -311,6 +331,13 @@ def _generate_one_example( else: lm_lprobs_padded = None + if self.model_predicts_eos: + # merge blank prob and EOS prob and set EOS prob to 0 to mitigate large del errors + lprobs[:, self.blank] = torch.logaddexp( + lprobs[:, self.blank], lprobs[:, self.eos] + ) + lprobs[:, self.eos] = float("-inf") + # compute k expansions for all the current hypotheses k_expanded_hyps = select_k_expansions( hyps, diff --git a/espresso/tools/transducer_greedy_decoder.py b/espresso/tools/transducer_greedy_decoder.py index 8790cc173..08a1f51c6 100644 --- a/espresso/tools/transducer_greedy_decoder.py +++ b/espresso/tools/transducer_greedy_decoder.py @@ -9,7 +9,7 @@ from torch import Tensor from espresso.tools.transducer_base_decoder import TransducerBaseDecoder -from espresso.tools.utils import clone_cached_state +from fairseq.utils import apply_to_sample class TransducerGreedyDecoder(TransducerBaseDecoder): @@ -21,6 +21,9 @@ def __init__( max_num_expansions_per_step=2, temperature=1.0, eos=None, + bos=None, + blank=None, + model_predicts_eos=False, symbols_to_strip_from_output=None, lm_model=None, lm_weight=1.0, @@ -39,6 +42,15 @@ def __init__( temperature (float, optional): temperature, where values >1.0 produce more uniform samples and values <1.0 produce sharper samples (default: 1.0) + eos (int, optional): index of eos. Will be dictionary.eos() if None + (default: None) + bos (int, optional): index of bos. Will be dictionary.eos() if None + (default: None) + blank (int, optional): index of blank. 
Will be dictionary.bos() if + None (default: None) + model_predicts_eos(bool, optional): enable it if the transducer model was + trained to predict EOS. Probability mass of emitting EOS will be transferred + to BLANK to alleviate early stop issue during decoding (default: False) lm_model (fairseq.models.FairseqLanguageModel, optional): LM model for LM fusion (default: None) lm_weight (float, optional): LM weight for LM fusion (default: 1.0) print_alignment (bool, optional): if True returns alignments (default: False) @@ -50,6 +62,9 @@ def __init__( max_num_expansions_per_step=max_num_expansions_per_step, temperature=temperature, eos=eos, + bos=bos, + blank=blank, + model_predicts_eos=model_predicts_eos, symbols_to_strip_from_output=symbols_to_strip_from_output, lm_model=lm_model, lm_weight=lm_weight, @@ -99,7 +114,7 @@ def _generate( dtype=torch.long, ) # +1 for the last blank at each time step prev_nonblank_tokens = tokens.new_full( - (bsz, 1), self.eos if bos_token is None else bos_token + (bsz, 1), self.bos if bos_token is None else bos_token ) # B x 1 # scores is used to store log-prob of emitting each token scores = enc_out.new_full( @@ -133,8 +148,9 @@ def _generate( not blank_mask.all() and expansion_idx < self.max_num_expansions_per_step + 1 ): - old_cached_state = clone_cached_state( - self.model.decoder.get_cached_state(incremental_state) + old_cached_state = apply_to_sample( + torch.clone, + self.model.decoder.get_cached_state(incremental_state), ) dec_out = self.model.decoder.extract_features( prev_nonblank_tokens, incremental_state=incremental_state @@ -153,8 +169,9 @@ def _generate( ) # B x V if self.lm_model is not None: - old_lm_cached_state = clone_cached_state( - self.lm_model.decoder.get_cached_state(lm_incremental_state) + old_lm_cached_state = apply_to_sample( + torch.clone, + self.lm_model.decoder.get_cached_state(lm_incremental_state), ) lm_prev_nonblank_tokens = ( torch.where( @@ -192,6 +209,13 @@ def _generate( ) # B x (V - 1) lprobs[:, self.vocab_nonblank_mask] = lprobs_with_lm_no_blank + if self.model_predicts_eos: + # merge blank prob and EOS prob and set EOS prob to 0 to mitigate large del errors + lprobs[:, self.blank] = torch.logaddexp( + lprobs[:, self.blank], lprobs[:, self.eos] + ) + lprobs[:, self.eos] = float("-inf") + if expansion_idx < self.max_num_expansions_per_step: ( scores[:, step, expansion_idx], diff --git a/espresso/tools/transducer_utils.py b/espresso/tools/transducer_utils.py index c0ae01eaf..e5f69cc8c 100644 --- a/espresso/tools/transducer_utils.py +++ b/espresso/tools/transducer_utils.py @@ -103,8 +103,8 @@ def index_select_(self, index: Tensor) -> Hypotheses: def sort_by_score_( self, - descending: Optional[bool] = False, - normalize_by_length: Optional[bool] = False, + descending: bool = False, + normalize_by_length: bool = False, ) -> Hypotheses: """Sorts the hypotheses in ascending/descending order of their scores which are optionally normalized by sequence length. @@ -129,7 +129,7 @@ def sort_by_score_( return self.index_select_(sort_order) - def sort_by_length_(self, descending: Optional[bool] = False) -> Hypotheses: + def sort_by_length_(self, descending: bool = False) -> Hypotheses: """Sorts the hypotheses in ascending/descending order of the predicted sequence lengths. Note: this function will modify this instance. 
@@ -163,9 +163,9 @@ def keep_first_k_(self, k: int) -> Hypotheses: def keep_top_k_( self, k: int, - largest: Optional[bool] = True, - sorted: Optional[bool] = True, - normalize_by_length: Optional[bool] = False, + largest: bool = True, + sorted: bool = True, + normalize_by_length: bool = False, ) -> Hypotheses: """Keeps the k-best hypotheses of this instance based on their scores and discards the rest. This function is usually faster than @@ -281,7 +281,7 @@ def append_tokens_( time_step: int, expansion_idx: int, blank_idx: int, - pad_idx: Optional[int] = 0, + pad_idx: int = 0, ) -> Hypotheses: """Appends non-blank tokens in `tokens` to `self.sequences`, allocates additional memory for `self.dec_out` and `self.lm_dec_out` if needed (which will later be updated by :func:`~Hypotheses.update_dec_out_()`), @@ -489,9 +489,7 @@ def repeat_interleave(self, repeats: int) -> Hypotheses: lm_dec_out=lm_dec_out, ) - def combine( - self, another_hyps: Hypotheses, pad_idx: Optional[int] = 0 - ) -> Hypotheses: + def combine(self, another_hyps: Hypotheses, pad_idx: int = 0) -> Hypotheses: """Returns a new instance of :class:`~Hypotheses` where it combines hypotheses from this instance and `another_hyps`. It does tensor concatenations along the batch dimension, after padding along the time dimension if needed. Note: this function will NOT modify this instance. @@ -645,11 +643,11 @@ def select_k_expansions( time_step: int, expansion_idx: int, blank_idx: int, - pad_idx: Optional[int] = 0, + pad_idx: int = 0, lm_lprobs_padded: Optional[Tensor] = None, gamma: Optional[float] = None, - beta: Optional[int] = 0, - normalize_by_length: Optional[bool] = False, + beta: int = 0, + normalize_by_length: bool = False, ) -> Hypotheses: """Returns K hypotheses candidates for expansions from a set of hypotheses. K candidates are selected according to the extended hypotheses probabilities @@ -685,7 +683,9 @@ def select_k_expansions( """ assert hyps.size() > 0 lprobs = lprobs + hyps.scores.unsqueeze(-1) # B x V - k = min(beam_size + beta, lprobs.size(1) - 1) # -1 so we never select pad + k = min( + beam_size + beta, lprobs.size(1) - (1 if pad_idx != blank_idx else 0) + ) # -1 for pad scores, indices = torch.topk(lprobs, k=k) # B x K k_expanded_hyps = hyps.repeat_interleave(k) # (B * K) hypotheses k_expanded_hyps.scores = scores.view(-1) # (B * K) @@ -720,7 +720,7 @@ def is_prefix(a: List[Union[int, str]], b: List[Union[int, str]]): return True -def is_prefix_tensorized(hyps: Hypotheses, are_sorted: Optional[bool] = False): +def is_prefix_tensorized(hyps: Hypotheses, are_sorted: bool = False): """Returns a mask tensor where the (i, j)-th element indicates if the i-th row of `hyps.sequences` is a prefix of the j-th row. diff --git a/espresso/tools/utils.py b/espresso/tools/utils.py index 4b6ac1aa0..6fbb6329d 100644 --- a/espresso/tools/utils.py +++ b/espresso/tools/utils.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import collections import os import re from collections import Counter @@ -57,6 +58,42 @@ def tokenize(sent, space="", non_lang_syms=None): return " ".join(tokens) +def apply_to_sample_pair(f, sample_a, sample_b): + """Recursively applies function f to each element pair in sample pair. 
+
+    Args:
+        f: function to apply to each element pair
+        sample_a/b: can be any nested dictionary/list/tuple of tensors
+
+    Returns:
+        sample: output sample preserving the same structure as the input sample
+            pair, with f applied to its elements
+    """
+
+    def _apply(a, b):
+        if torch.is_tensor(a):
+            return f(a, b)
+        if isinstance(a, collections.OrderedDict):
+            # OrderedDict has attributes that need to be preserved
+            odict = collections.OrderedDict(
+                (key, _apply(value, b[key])) for key, value in a.items()
+            )
+            odict.__dict__ = a.__dict__
+            return odict
+        if isinstance(a, dict):
+            return {key: _apply(value, b[key]) for key, value in a.items()}
+        if isinstance(a, list):
+            return [_apply(x, y) for x, y in zip(a, b)]
+        if isinstance(a, tuple):
+            return tuple(_apply(x, y) for x, y in zip(a, b))
+        if a is None:
+            assert b is None
+            return None
+        return a
+
+    return _apply(sample_a, sample_b)
+
+
 def collate_frames(
     values, pad_value=0.0, left_pad=False, pad_to_length=None, pad_to_multiple=1
 ):
@@ -388,20 +425,6 @@ def aligned_print(ref, hyp, steps):
     return out_str
 
 
-def clone_cached_state(
-    cached_state: Tuple[Optional[Union[List[torch.Tensor], torch.Tensor]]]
-):
-    if cached_state is None:
-        return None
-
-    def clone_state(state):
-        if isinstance(state, list):
-            return [clone_state(state_i) for state_i in state]
-        return state.clone() if state is not None else None
-
-    return tuple(map(clone_state, cached_state))
-
-
 def get_torchaudio_fbank_or_mfcc(
     waveform: np.ndarray,
     sample_rate: float,
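
A minimal usage sketch of the new apply_to_sample_pair helper added above (not part of the patch); it mirrors how masked_copy_cached_state in speech_lstm.py now merges source and target cached states with torch.where. The state dictionaries and tensor shapes below are illustrative assumptions, not code from the repository.

    import torch

    from espresso.tools.utils import apply_to_sample_pair  # available once this patch is applied

    # Two structurally identical "cached states": nested containers of tensors.
    src_state = {"prev_hiddens": [torch.ones(4, 8)], "input_feed": torch.ones(4, 8)}
    tgt_state = {"prev_hiddens": [torch.zeros(4, 8)], "input_feed": torch.zeros(4, 8)}

    # Rows where mask is True are taken from src_state; the other rows keep tgt_state.
    mask = torch.tensor([True, False, True, False]).unsqueeze(1)  # B x 1, broadcasts over features
    merged = apply_to_sample_pair(
        lambda x, y, z=mask: torch.where(z, x, y), src_state, tgt_state
    )
    print(merged["input_feed"])  # rows 0 and 2 are ones, rows 1 and 3 are zeros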