diff --git a/espresso/models/transformer/speech_transformer_base.py b/espresso/models/transformer/speech_transformer_base.py
index 060df3496..d6a45d16c 100644
--- a/espresso/models/transformer/speech_transformer_base.py
+++ b/espresso/models/transformer/speech_transformer_base.py
@@ -126,18 +126,6 @@ def build_model(cls, cfg, task):
         else:
             transformer_encoder_input_size = task.feat_dim
 
-        encoder_transformer_context = speech_utils.eval_str_nested_list_or_tuple(
-            cfg.encoder.transformer_context,
-            type=int,
-        )
-        if encoder_transformer_context is not None:
-            assert len(encoder_transformer_context) == 2
-            for i in range(2):
-                assert encoder_transformer_context[i] is None or (
-                    isinstance(encoder_transformer_context[i], int)
-                    and encoder_transformer_context[i] >= 0
-                )
-
         scheduled_sampling_rate_scheduler = ScheduledSamplingRateScheduler(
             cfg.scheduled_sampling_probs,
             cfg.start_scheduled_sampling_epoch,
@@ -147,7 +135,6 @@ def build_model(cls, cfg, task):
             cfg,
             pre_encoder=conv_layers,
             input_size=transformer_encoder_input_size,
-            transformer_context=encoder_transformer_context,
         )
         decoder = cls.build_decoder(
             cfg,
@@ -162,14 +149,9 @@ def set_num_updates(self, num_updates):
         super().set_num_updates(num_updates)
 
     @classmethod
-    def build_encoder(
-        cls, cfg, pre_encoder=None, input_size=83, transformer_context=None
-    ):
+    def build_encoder(cls, cfg, pre_encoder=None, input_size=83):
         return SpeechTransformerEncoderBase(
-            cfg,
-            pre_encoder=pre_encoder,
-            input_size=input_size,
-            transformer_context=transformer_context,
+            cfg, pre_encoder=pre_encoder, input_size=input_size
         )
 
     @classmethod
diff --git a/espresso/models/transformer/speech_transformer_config.py b/espresso/models/transformer/speech_transformer_config.py
index 828f6ae71..396d7fc1f 100644
--- a/espresso/models/transformer/speech_transformer_config.py
+++ b/espresso/models/transformer/speech_transformer_config.py
@@ -62,6 +62,22 @@ class SpeechEncoderConfig(SpeechEncDecBaseConfig):
     layer_type: LAYER_TYPE_CHOICES = field(
         default="transformer", metadata={"help": "layer type in encoder"}
     )
+    chunk_size: int = field(
+        default=0,
+        metadata={"help": "chunk size of Transformer in chunk streaming mode if > 0"},
+    )
+    chunk_left_window: int = field(
+        default=0,
+        metadata={
+            "help": "number of chunks to the left of the current chunk in chunk streaming mode"
+        },
+    )
+    chunk_right_window: int = field(
+        default=0,
+        metadata={
+            "help": "number of chunks to the right of the current chunk in chunk streaming mode"
+        },
+    )
     # config specific to Conformer
     depthwise_conv_kernel_size: int = field(
         default=31,
diff --git a/espresso/models/transformer/speech_transformer_encoder.py b/espresso/models/transformer/speech_transformer_encoder.py
index bc605f6aa..e9cea6cbb 100644
--- a/espresso/models/transformer/speech_transformer_encoder.py
+++ b/espresso/models/transformer/speech_transformer_encoder.py
@@ -17,6 +17,7 @@
     RelativePositionalEmbedding,
     TransformerWithRelativePositionalEmbeddingEncoderLayerBase,
 )
+from fairseq.data import data_utils
 from fairseq.distributed import fsdp_wrap
 from fairseq.models.transformer import Linear, TransformerEncoderBase
 from fairseq.modules import (
@@ -59,7 +60,6 @@ def __init__(
         return_fc=False,
         pre_encoder=None,
         input_size=83,
-        transformer_context=None,
     ):
         self.cfg = cfg
         super(TransformerEncoderBase, self).__init__(None)  # no src dictionary
@@ -159,7 +159,19 @@ def __init__(
         else:
             self.layer_norm = None
 
-        self.transformer_context = transformer_context
+        self.transformer_context = speech_utils.eval_str_nested_list_or_tuple(
+            cfg.encoder.transformer_context,
+            type=int,
+        )
+        if self.transformer_context is not None:
+            assert len(self.transformer_context) == 2
+            for i in range(2):
+                assert self.transformer_context[i] is None or (
+                    isinstance(self.transformer_context[i], int)
+                    and self.transformer_context[i] >= 0
+                )
+
+        self.num_updates = 0
 
     def build_encoder_layer(
         self, cfg, positional_embedding: Optional[RelativePositionalEmbedding] = None
@@ -183,6 +195,10 @@ def build_encoder_layer(
             layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
         return layer
 
+    def set_num_updates(self, num_updates):
+        self.num_updates = num_updates
+        super().set_num_updates(num_updates)
+
     def output_lengths(self, in_lengths):
         return (
             in_lengths
@@ -204,6 +220,16 @@ def get_attn_mask(self, in_lengths):
         `attn_mask[tgt_i, src_j] = 1` means that when calculating the embedding
         for `tgt_i`, we exclude (mask out) `src_j`.
         """
+        if self.cfg.encoder.chunk_size > 0:
+            with data_utils.numpy_seed(self.num_updates):
+                return ~speech_utils.chunk_streaming_mask(
+                    in_lengths,
+                    self.cfg.encoder.chunk_size,
+                    left_window=self.cfg.encoder.chunk_left_window,
+                    right_window=self.cfg.encoder.chunk_right_window,
+                    always_partial_in_last=(not self.training),
+                )
+
         if self.transformer_context is None or (
             self.transformer_context[0] is None and self.transformer_context[1] is None
         ):
@@ -383,7 +409,6 @@ def __init__(
         return_fc=False,
         pre_encoder=None,
         input_size=83,
-        transformer_context=None,
     ):
         self.args = args
         super().__init__(
@@ -391,7 +416,6 @@ def __init__(
             return_fc=return_fc,
             pre_encoder=pre_encoder,
             input_size=input_size,
-            transformer_context=transformer_context,
         )
 
     def build_encoder_layer(
diff --git a/espresso/models/transformer/speech_transformer_encoder_model.py b/espresso/models/transformer/speech_transformer_encoder_model.py
index 7d835bf5a..a862654e7 100644
--- a/espresso/models/transformer/speech_transformer_encoder_model.py
+++ b/espresso/models/transformer/speech_transformer_encoder_model.py
@@ -102,23 +102,10 @@ def build_model(cls, cfg, task):
         else:
             transformer_encoder_input_size = task.feat_dim
 
-        encoder_transformer_context = speech_utils.eval_str_nested_list_or_tuple(
-            cfg.encoder.transformer_context,
-            type=int,
-        )
-        if encoder_transformer_context is not None:
-            assert len(encoder_transformer_context) == 2
-            for i in range(2):
-                assert encoder_transformer_context[i] is None or (
-                    isinstance(encoder_transformer_context[i], int)
-                    and encoder_transformer_context[i] >= 0
-                )
-
         encoder = cls.build_encoder(
             cfg,
             pre_encoder=conv_layers,
             input_size=transformer_encoder_input_size,
-            transformer_context=encoder_transformer_context,
             vocab_size=(
                 len(task.target_dictionary)
                 if task.target_dictionary is not None
@@ -139,14 +126,12 @@ def build_encoder(
         cfg,
         pre_encoder=None,
         input_size=83,
-        transformer_context=None,
         vocab_size=None,
     ):
         return SpeechTransformerEncoderForPrediction(
             cfg,
             pre_encoder=pre_encoder,
             input_size=input_size,
-            transformer_context=transformer_context,
             vocab_size=vocab_size,
         )
 
@@ -174,7 +159,6 @@ def __init__(
         return_fc=False,
         pre_encoder=None,
         input_size=83,
-        transformer_context=None,
         vocab_size=None,
     ):
         super().__init__(
@@ -182,7 +166,6 @@ def __init__(
             return_fc=return_fc,
             pre_encoder=pre_encoder,
             input_size=input_size,
-            transformer_context=transformer_context,
         )
 
         self.fc_out = (
diff --git a/espresso/models/transformer/speech_transformer_transducer_base.py b/espresso/models/transformer/speech_transformer_transducer_base.py
index 8663186df..f024fb968 100644
--- a/espresso/models/transformer/speech_transformer_transducer_base.py
+++ b/espresso/models/transformer/speech_transformer_transducer_base.py
@@ -165,23 +165,10 @@ def build_model(cls, cfg, task):
         else:
             transformer_encoder_input_size = task.feat_dim
 
-        encoder_transformer_context = speech_utils.eval_str_nested_list_or_tuple(
-            cfg.encoder.transformer_context,
-            type=int,
-        )
-        if encoder_transformer_context is not None:
-            assert len(encoder_transformer_context) == 2
-            for i in range(2):
-                assert encoder_transformer_context[i] is None or (
-                    isinstance(encoder_transformer_context[i], int)
-                    and encoder_transformer_context[i] >= 0
-                )
-
         encoder = cls.build_encoder(
             cfg,
             pre_encoder=conv_layers,
             input_size=transformer_encoder_input_size,
-            transformer_context=encoder_transformer_context,
         )
         decoder = cls.build_decoder(cfg, tgt_dict, decoder_embed_tokens)
         # fsdp_wrap is a no-op when --ddp-backend != fully_sharded
@@ -206,14 +193,11 @@ def build_embedding(cls, cfg, dictionary, embed_dim, path=None):
         return emb
 
     @classmethod
-    def build_encoder(
-        cls, cfg, pre_encoder=None, input_size=83, transformer_context=None
-    ):
+    def build_encoder(cls, cfg, pre_encoder=None, input_size=83):
         return SpeechTransformerEncoderBase(
             cfg,
             pre_encoder=pre_encoder,
             input_size=input_size,
-            transformer_context=transformer_context,
         )
 
     @classmethod
diff --git a/espresso/tools/utils.py b/espresso/tools/utils.py
index 840a10046..64efbcceb 100644
--- a/espresso/tools/utils.py
+++ b/espresso/tools/utils.py
@@ -12,6 +12,7 @@
 
 import numpy as np
 import torch
+import torch.nn.functional as F
 
 try:
     import kaldi_io
@@ -90,6 +91,74 @@ def sequence_mask(sequence_length, max_len=None):
     return seq_range_expand < seq_length_expand
 
 
+def chunk_streaming_mask(
+    sequence_length: torch.Tensor,
+    chunk_size: int,
+    left_window: int = 0,
+    right_window: int = 0,
+    always_partial_in_last: bool = False,
+):
+    """Returns a mask for chunk streaming Transformer models.
+
+    Args:
+        sequence_length (LongTensor): sequence lengths of shape `(batch)`
+        chunk_size (int): chunk size
+        left_window (int): how many left chunks can be seen (default: 0)
+        right_window (int): how many right chunks can be seen (default: 0)
+        always_partial_in_last (bool): if True, always make the last chunk the partial one;
+            otherwise randomly make either the first or the last chunk partial,
+            to avoid learning to emit EOS based only on a partial chunk size
+            (default: False)
+
+    Returns:
+        mask (BoolTensor): a mask tensor of shape `(tgt_len, src_len)`, where
+            `tgt_len` is the length of output and `src_len` is the length of input;
+            `mask[tgt_i, src_j] = True` means that when calculating the embedding
+            for `tgt_i`, we need `src_j`.
+    """
+
+    max_len = sequence_length.data.max()
+    chunk_start_idx = torch.arange(
+        0,
+        max_len,
+        chunk_size,
+        dtype=sequence_length.dtype,
+        device=sequence_length.device,
+    )  # e.g. [0,18,36,54]
+    if not always_partial_in_last and np.random.rand() > 0.5:
+        # make either the first or the last chunk partial; if only the last one could be partial, EOS could be predicted from chunk size
+        chunk_start_idx = max_len - chunk_start_idx
+        chunk_start_idx = chunk_start_idx.flip([0])
+        chunk_start_idx = chunk_start_idx[:-1]
+        chunk_start_idx = F.pad(chunk_start_idx, (1, 0))
+
+    start_pad = torch.nn.functional.pad(chunk_start_idx, (1, 0))  # [0,0,18,36,54]
+    end_pad = torch.nn.functional.pad(
+        chunk_start_idx, (0, 1), value=max_len
+    )  # [0,18,36,54,max_len]
+    seq_range = torch.arange(
+        0, max_len, dtype=sequence_length.dtype, device=sequence_length.device
+    )
+    idx = (
+        (seq_range.unsqueeze(-1) >= start_pad) & (seq_range.unsqueeze(-1) < end_pad)
+    ).nonzero()[
+        :, 1
+    ]  # max_len
+    seq_range_expand = seq_range.unsqueeze(0).expand(max_len, -1)  # max_len x max_len
+
+    idx_left = idx - left_window
+    idx_left[idx_left < 0] = 0
+    boundary_left = start_pad[idx_left]  # max_len
+    mask_left = seq_range_expand >= boundary_left.unsqueeze(-1)
+
+    idx_right = idx + right_window
+    idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx)
+    boundary_right = end_pad[idx_right]  # max_len
+    mask_right = seq_range_expand < boundary_right.unsqueeze(-1)
+
+    return mask_left & mask_right
+
+
 def convert_padding_direction(
     src_frames,
     src_lengths,
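
A minimal usage sketch (not part of the diff) of how the new chunk_streaming_mask helper behaves and how get_attn_mask() consumes it. It assumes espresso.tools.utils is importable as speech_utils; the lengths, chunk size, and window sizes below are made up for illustration.

import torch

from espresso.tools import utils as speech_utils

# Only the maximum length in the batch determines the mask shape.
lengths = torch.LongTensor([10, 7])

# always_partial_in_last=True (what the encoder passes when not training)
# skips the random first-vs-last partial-chunk choice, so the result is
# deterministic: chunks are [0-3], [4-7], [8-9].
mask = speech_utils.chunk_streaming_mask(
    lengths,
    chunk_size=4,
    left_window=1,
    right_window=0,
    always_partial_in_last=True,
)

# mask is a (10, 10) BoolTensor where True means "may attend":
#   frames 0-3 attend to frames 0-3,
#   frames 4-7 attend to frames 0-7 (one chunk of left context),
#   frames 8-9 attend to frames 4-9.
print(mask.shape)  # torch.Size([10, 10])
print(mask[5])     # True for positions 0..7, False for 8..9

# get_attn_mask() negates it, because its convention is the opposite:
# there, True means "mask out".
attn_mask = ~mask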