Commit
transducer decoding fix; cosmetic changes
freewym committed Oct 29, 2022
1 parent 1e2934b commit db4eeb2
Showing 6 changed files with 140 additions and 113 deletions.
2 changes: 1 addition & 1 deletion espresso/speech_recognize.py
@@ -334,7 +334,7 @@ def decode_fn(x):
if save_attention_plot:
logger.info("Saved attention plots in " + save_dir)

if has_target:
if has_target and hasattr(task.datasets[cfg.dataset.gen_subset], "tgt"):
scorer.add_ordered_utt_list(task.datasets[cfg.dataset.gen_subset].tgt.utt_ids)

fn = "decoded_char_results.txt"
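A small illustration of why the added check matters, using a hypothetical dataset class standing in for a gen subset built without reference text; this is not the project's dataset implementation:

class UntranscribedDataset:
    pass                            # no `tgt` attribute, e.g. untranscribed test audio

dataset = UntranscribedDataset()
print(hasattr(dataset, "tgt"))      # False, so scorer.add_ordered_utt_list(...) is skipped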
6 changes: 5 additions & 1 deletion espresso/tools/compute_wer.py
@@ -75,7 +75,11 @@ def main(args):
wer_counter = Counter()
with open(args.hyp_text, "r", encoding="utf-8") as f:
for line in f:
utt_id, text = line.strip().split(None, 1)
res = line.strip().split(None, 1)
if len(res) == 2:
utt_id, text = res
else:
utt_id, text = res[0], ""
assert utt_id in refs, utt_id
ref, hyp = refs[utt_id], text

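A minimal sketch of the more tolerant parsing introduced above, assuming hypothesis lines of the form "<utt_id> <text ...>" where the text may be empty when an utterance produced no output:

def parse_hyp_line(line: str):
    res = line.strip().split(None, 1)
    if len(res) == 2:
        utt_id, text = res
    else:
        utt_id, text = res[0], ""   # empty hypothesis: keep the utt_id so it is still scored
    return utt_id, text

print(parse_hyp_line("utt1 hello world"))   # ('utt1', 'hello world')
print(parse_hyp_line("utt2"))               # ('utt2', ''); the old 2-tuple unpacking raised ValueError here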
32 changes: 26 additions & 6 deletions espresso/tools/transducer_base_decoder.py
@@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.

import logging
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -33,8 +33,8 @@ def __init__(
Args:
models (List[~fairseq.models.FairseqModel]): ensemble of models
dictionary (~fairseq.data.Dictionary): dictionary
max_len (int, optional): the maximum length of the generated output
(not including end-of-sentence) (default: 0, no limit)
max_len (int, optional): the maximum length of the encoder output
that can emit tokens (default: 0, no limit)
max_num_expansions_per_step (int, optional): the maximum number of
non-blank expansions in a single time step (default: 2)
temperature (float, optional): temperature, where values
@@ -95,7 +95,9 @@ def cuda(self):
return self

@torch.no_grad()
def decode(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs):
def decode(
self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs
) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
"""Generate a batch of 1-best hypotheses. Match the API of other fairseq generators.
Normally called for validation during training.
@@ -139,6 +141,24 @@ def generate(
@torch.no_grad()
def _generate(
self, sample: Dict[str, Dict[str, Tensor]], bos_token: Optional[int] = None
) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
# should return a tuple of tokens, scores and alignments
) -> Tuple[
Union[Tensor, List[Tensor]],
Union[Tensor, List[Tensor]],
Optional[Union[Tensor, List[Tensor]]],
]:
"""Implement the algorithm here.
Should return a tuple of tokens, scores and alignments.
Args:
feature (Tensor): feature of shape
`(batch, feature_length, feature_dim)`
feature_lens (Tensor, optional): feature lengths of shape `(batch)`
Returns:
tokens (LongTensor or List[LongTensor]): token sequences of shape
`(batch, max_dec_out_length)`
scores (FloatTensor or List[FloatTensor]): scores of shape `(batch)`
alignments (LongTensor or List[LongTensor], optional): alignments of
shape `(batch, max_enc_out_length, max_num_tokens_per_step)`
"""
raise NotImplementedError
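A hypothetical toy subclass, only to illustrate the return contract documented above (token sequences of shape `(batch, max_dec_out_length)`, scores of shape `(batch)`, optional alignments); it is not the project's decoder:

import torch
from torch import Tensor
from typing import Dict, Optional, Tuple


class ToyTransducerDecoder:
    """Emits a single EOS-only hypothesis per batch element."""

    def __init__(self, eos: int = 2):
        self.eos = eos

    @torch.no_grad()
    def _generate(
        self, sample: Dict[str, Dict[str, Tensor]], bos_token: Optional[int] = None
    ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
        bsz = sample["net_input"]["src_tokens"].size(0)
        tokens = torch.full((bsz, 1), self.eos, dtype=torch.long)   # B x max_dec_out_length
        scores = torch.zeros(bsz)                                   # B
        return tokens, scores, None                                 # alignments omitted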
186 changes: 94 additions & 92 deletions espresso/tools/transducer_beam_search_decoder.py
@@ -3,7 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple

import torch
from torch import Tensor
@@ -46,8 +46,8 @@ def __init__(
models (List[~fairseq.models.FairseqModel]): ensemble of models
dictionary (~fairseq.data.Dictionary): dictionary
beam_size (int, optional): beam width (default: 1)
max_len (int, optional): the maximum length of the generated output
(not including end-of-sentence) (default: 0, no limit)
max_len (int, optional): the maximum length of the encoder output
that can emit tokens (default: 0, no limit)
max_num_expansions_per_step (int, optional): the maximum number of
non-blank expansions in a single time step (default: 2)
expansion_beta (int, optional): maximum number of prefix expansions allowed,
@@ -63,8 +63,8 @@
prefix_alpha (int, optional): maximum prefix length in prefix search.
Must be an integer, and is advised to keep this as 1 in order to
reduce expensive beam search cost later (default: 1)
normalize_scores (bool, optional): normalize scores by the length
of the output including blank (default: True)
normalize_scores (bool, optional): if True, normalize scores by the length
of the output including blank when sorting/ranking hyps (default: True)
temperature (float, optional): temperature, where values
>1.0 produce more uniform samples and values <1.0 produce
sharper samples (default: 1.0)
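A small arithmetic illustration of the expansion_beta knob described above, under the assumption (the full docstring is truncated here) that each expansion round considers at most beam_size + expansion_beta candidate prefix expansions:

beam_size, expansion_beta = 4, 2
max_candidate_expansions = beam_size + expansion_beta   # at most 6 prefix expansions per round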
@@ -101,7 +101,9 @@ def __init__(
self.normalize_scores = normalize_scores

@torch.no_grad()
def decode(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs):
def decode(
self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs
) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
assert not self.print_alignment

tokens_list, scores_list, _ = self._generate(sample, **kwargs)
@@ -149,7 +151,7 @@ def generate(
@torch.no_grad()
def _generate(
self, sample: Dict[str, Dict[str, Tensor]], bos_token: Optional[int] = None
):
) -> Tuple[List[Tensor], List[Tensor], Optional[List[Tensor]]]:
net_input = sample["net_input"]
src_tokens = net_input["src_tokens"]
bsz, src_len = src_tokens.size()[:2]
@@ -158,14 +160,16 @@ def _generate(
encoder_outs = self.model.encoder.forward_torchscript(net_input)
enc_out = encoder_outs["encoder_out"][0].transpose(0, 1) # B x T x C
enc_out_lengths = encoder_outs["src_lengths"][0] # B
sequences_list, scores_list, alignments_list = [], [], []
sequences_list, scores_list = [], []
alignments_list = [] if self.print_alignment else None
for i in range(bsz):
sequences, scores, alignments = self._generate_one_example(
enc_out[i : i + 1, :, :], enc_out_lengths[i], bos_token=bos_token
)
sequences_list.append(sequences)
scores_list.append(scores)
alignments_list.append(alignments)
if self.print_alignment:
alignments_list.append(alignments)

return sequences_list, scores_list, alignments_list

@@ -185,7 +189,7 @@ def _generate_one_example(
if self.print_alignment:
# token alignments for the final best hypothesis. +1 for blank
alignments = enc_out.new_full(
(enc_out_length, self.max_num_expansions_per_step + 1),
(1, enc_out_length, self.max_num_expansions_per_step + 1),
self.pad,
dtype=torch.long,
) # T x N (max num tokens per time step)
@@ -336,102 +340,100 @@ def _generate_one_example(
self.beam_size, normalize_by_length=self.normalize_scores
)
break
else:
# forward the decoder with `k_expanded_hyps_nonblank`
self.model.decoder.set_incremental_state(
incremental_state,
# forward the decoder with `k_expanded_hyps_nonblank`
self.model.decoder.set_incremental_state(
incremental_state,
"cached_state",
k_expanded_hyps_nonblank.cached_state,
)
dec_out = self.model.decoder.extract_features(
k_expanded_hyps_nonblank.prev_tokens.unsqueeze(1),
incremental_state=incremental_state,
)[
0
] # B x 1 x H
k_expanded_hyps_nonblank.cached_state = (
self.model.decoder.get_incremental_state(
incremental_state, "cached_state"
)
) # update `cached_state`

if self.lm_model is not None:
self.lm_model.decoder.set_incremental_state(
lm_incremental_state,
"cached_state",
k_expanded_hyps_nonblank.cached_state,
k_expanded_hyps_nonblank.lm_cached_state,
)
dec_out = self.model.decoder.extract_features(
k_expanded_hyps_nonblank.prev_tokens.unsqueeze(1),
incremental_state=incremental_state,
prev_tokens = k_expanded_hyps_nonblank.prev_tokens
lm_prev_tokens = (
torch.where(
prev_tokens > self.blank, prev_tokens - 1, prev_tokens
)
if self.no_blank_in_lm
else prev_tokens
)
lm_dec_out = self.lm_model.extract_features(
lm_prev_tokens.unsqueeze(1),
incremental_state=lm_incremental_state,
)[
0
] # B x 1 x H
k_expanded_hyps_nonblank.cached_state = (
self.model.decoder.get_incremental_state(
incremental_state, "cached_state"
] # B x 1 x H'
k_expanded_hyps_nonblank.lm_cached_state = (
self.lm_model.decoder.get_incremental_state(
lm_incremental_state, "cached_state"
)
) # update `cached_state`

if self.lm_model is not None:
self.lm_model.decoder.set_incremental_state(
lm_incremental_state,
"cached_state",
k_expanded_hyps_nonblank.lm_cached_state,
)
prev_tokens = k_expanded_hyps_nonblank.prev_tokens
lm_prev_tokens = (
torch.where(
prev_tokens > self.blank, prev_tokens - 1, prev_tokens
)
if self.no_blank_in_lm
else prev_tokens
) # update `lm_cached_state`
else:
lm_dec_out = None

k_expanded_hyps_nonblank.update_dec_out_(
dec_out, lm_dec_out=lm_dec_out
) # update `dec_out` and `lm_dec_out`

if (
expansion_idx < self.max_num_expansions_per_step - 1
): # not the last round of expansion within this time step
# prepare for the next round of expansion within this time step
hyps = k_expanded_hyps_nonblank
else: # the last round of expansion within this time step
# add blank probability to non-blank hyps, combine and prune the hyps for the next time step
logits = (
self.model.joint(
enc_out_this_step.expand(dec_out.size(0), -1, -1),
dec_out,
apply_output_layer=True,
)
lm_dec_out = self.lm_model.extract_features(
lm_prev_tokens.unsqueeze(1),
incremental_state=lm_incremental_state,
)[
0
] # B x 1 x H'
k_expanded_hyps_nonblank.lm_cached_state = (
self.lm_model.decoder.get_incremental_state(
lm_incremental_state, "cached_state"
)
) # update `lm_cached_state`
else:
lm_dec_out = None

k_expanded_hyps_nonblank.update_dec_out_(
dec_out, lm_dec_out=lm_dec_out
) # update `dec_out` and `lm_dec_out`

if (
expansion_idx < self.max_num_expansions_per_step - 1
): # not the last round of expansion within this time step
# prepare for the next round of expansion within this time step
hyps = k_expanded_hyps_nonblank
else: # the last round of expansion within this time step
# add blank probability to non-blank hyps, combine and prune the hyps for the next time step
logits = (
self.model.joint(
enc_out_this_step.expand(dec_out.size(0), -1, -1),
dec_out,
apply_output_layer=True,
)
.squeeze(2)
.squeeze(1)
) # B x 1 x 1 x V -> B x V
lprobs = self.model.get_normalized_probs(
(logits.div_(self.temperature), None), log_probs=True
) # B x V
.squeeze(2)
.squeeze(1)
) # B x 1 x 1 x V -> B x V
lprobs = self.model.get_normalized_probs(
(logits.div_(self.temperature), None), log_probs=True
) # B x V

k_expanded_hyps_nonblank.scores += lprobs[:, self.blank]
k_expanded_hyps_nonblank.prev_tokens.fill_(
self.blank
) # unnecessary but conceptually should do
k_expanded_hyps_nonblank.num_emissions += 1
k_expanded_hyps_nonblank.scores += lprobs[:, self.blank]
k_expanded_hyps_nonblank.prev_tokens.fill_(
self.blank
) # unnecessary but conceptually should do
k_expanded_hyps_nonblank.num_emissions += 1

next_step_hyps = k_expanded_hyps_blank.combine(
k_expanded_hyps_nonblank, pad_idx=self.pad
)
next_step_hyps.keep_top_k_(
self.beam_size, normalize_by_length=self.normalize_scores
)
next_step_hyps = k_expanded_hyps_blank.combine(
k_expanded_hyps_nonblank, pad_idx=self.pad
)
next_step_hyps.keep_top_k_(
self.beam_size, normalize_by_length=self.normalize_scores
)

next_step_hyps.sort_by_score_(
descending=True, normalize_by_length=self.normalize_scores
)
# normalize scores by sequence length and sort final hyps
next_step_hyps.scores.div_(next_step_hyps.sequence_lengths - 1) # B
next_step_hyps.sort_by_score_(descending=True)
# get the N-best hypotheses, and exclude the leading EOS token from the sequences
sequences = next_step_hyps.sequences[:, 1:] # B x U
scores = next_step_hyps.scores / (next_step_hyps.sequence_lengths - 1) # B
if self.print_alignment:
alignments = next_step_hyps.alignments # B x T x N
else:
alignments = None

return sequences, scores, alignments
return sequences, next_step_hyps.scores, alignments
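Beyond the re-indentation, the substantive change at the end of this hunk normalizes the final scores in place by the output length minus one (presumably excluding the leading EOS used to seed decoding) before sorting, instead of passing normalize_by_length to the sort and dividing afterwards. A minimal sketch with made-up values; the names mirror the fields above but none of the numbers come from the code:

import torch

scores = torch.tensor([-4.2, -3.9, -6.0])              # summed log-probs of three hypotheses
sequence_lengths = torch.tensor([5, 3, 4])              # lengths counting the leading EOS
sequences = torch.tensor([[2, 10, 11, 12, 13],
                          [2, 10, 14,  1,  1],
                          [2, 15, 16, 17,  1]])          # 1 = pad, 2 = EOS

norm_scores = scores / (sequence_lengths - 1)            # length-normalize, leading EOS not counted
order = norm_scores.argsort(descending=True)             # rank hypotheses by normalized score
best_tokens = sequences[order[0], 1:]                    # drop the leading EOS, as in the return above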

def prefix_search_and_merge(
self, hyps: Hypotheses, enc_out: Tensor, alpha: Optional[int] = None
@@ -481,7 +483,7 @@ def prefix_search_and_merge(
).squeeze(
0
) # 1 x V -> V
token_index = hyps.sequences[j][len_i]
token_index = hyps.sequences[j, len_i]
score = hyps.scores[i] + lprobs[token_index]

if self.lm_model is not None:
@@ -530,7 +532,7 @@ def prefix_search_and_merge(
).squeeze(
0
) # 1 x V -> V
token_index = hyps.sequences[j][k + 1]
token_index = hyps.sequences[j, k + 1]
score += lprobs[token_index]

if self.lm_model is not None:
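Among the cosmetic changes in this file, chained indexing hyps.sequences[j][len_i] becomes tuple indexing hyps.sequences[j, len_i]. Both select the same element of a 2-D tensor; the tuple form just skips materializing the intermediate row. A tiny sanity check on a stand-in tensor:

import torch

seqs = torch.arange(12).view(3, 4)           # stand-in for hyps.sequences
assert torch.equal(seqs[1][2], seqs[1, 2])   # identical elements either way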
8 changes: 4 additions & 4 deletions espresso/tools/transducer_greedy_decoder.py
@@ -3,7 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, Optional
from typing import Dict, Optional, Tuple

import torch
from torch import Tensor
@@ -32,8 +32,8 @@ def __init__(
Args:
models (List[~fairseq.models.FairseqModel]): ensemble of models
dictionary (~fairseq.data.Dictionary): dictionary
max_len (int, optional): the maximum length of the generated output
(not including end-of-sentence) (default: 0, no limit)
max_len (int, optional): the maximum length of the encoder output
that can emit tokens (default: 0, no limit)
max_num_expansions_per_step (int, optional): the maximum number of
non-blank expansions in a single time step (default: 2)
temperature (float, optional): temperature, where values
@@ -76,7 +76,7 @@ def __init__(
@torch.no_grad()
def _generate(
self, sample: Dict[str, Dict[str, Tensor]], bos_token: Optional[int] = None
):
) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
net_input = sample["net_input"]
src_tokens = net_input["src_tokens"]
bsz, src_len = src_tokens.size()[:2]
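The docstring change in this file (and in the other decoders) clarifies that max_len bounds the number of encoder output frames that may emit tokens rather than the length of the generated output. A minimal greedy transducer decoding sketch showing how max_len and max_num_expansions_per_step are typically used; step_decoder and joint are hypothetical callables, and this is not the project's implementation:

import torch
from torch import Tensor
from typing import Callable, List

def greedy_rnnt_decode(
    enc_out: Tensor,                 # T x C encoder output for one utterance
    step_decoder: Callable,          # (last_token, state) -> (dec_out, state); hypothetical
    joint: Callable,                 # (enc_frame, dec_out) -> log-probs over vocab; hypothetical
    blank: int,
    bos: int,
    max_num_expansions_per_step: int = 2,
    max_len: int = 0,
) -> List[int]:
    T = enc_out.size(0) if max_len == 0 else min(enc_out.size(0), max_len)
    tokens = [bos]
    dec_out, state = step_decoder(bos, None)       # prediction-network output for the current prefix
    for t in range(T):                             # only the first `max_len` frames may emit tokens
        for _ in range(max_num_expansions_per_step):
            pred = int(joint(enc_out[t], dec_out).argmax(-1))
            if pred == blank:
                break                              # blank: advance to the next encoder frame
            tokens.append(pred)                    # non-blank: emit and try another expansion at frame t
            dec_out, state = step_decoder(pred, state)   # advance the prediction network on non-blank only
    return tokens[1:]                              # drop the leading BOS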