From 41cdb008a9c5e79393152459d0fd94b33fb47c55 Mon Sep 17 00:00:00 2001
From: freewym
Date: Wed, 9 Nov 2022 20:41:23 -0500
Subject: [PATCH] makes BatchNorm configurable in the pre-encoder; turns off
 elementwise_affine for layernorm_embedding in the transformer encoder to
 stabilize training; adds the bias term to the joiner's final fc_out and uses
 its default initializer

---
 espresso/models/speech_lstm.py                 |  1 +
 .../transformer/speech_transformer_config.py   |  6 ++++
 .../transformer/speech_transformer_encoder.py  |  4 ++-
 .../speech_transformer_transducer_base.py      | 11 ++++--
 espresso/modules/speech_convolutions.py        | 35 +++++++++++++------
 5 files changed, 43 insertions(+), 14 deletions(-)

diff --git a/espresso/models/speech_lstm.py b/espresso/models/speech_lstm.py
index c0e075fa2..ed7f22449 100644
--- a/espresso/models/speech_lstm.py
+++ b/espresso/models/speech_lstm.py
@@ -235,6 +235,7 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
                 kernel_sizes,
                 strides,
                 in_channels=task.feat_in_channels,
+                apply_batchnorm=True,
             )
             if out_channels is not None
             else None
diff --git a/espresso/models/transformer/speech_transformer_config.py b/espresso/models/transformer/speech_transformer_config.py
index 396d7fc1f..6c44a7a53 100644
--- a/espresso/models/transformer/speech_transformer_config.py
+++ b/espresso/models/transformer/speech_transformer_config.py
@@ -52,6 +52,12 @@ class SpeechEncoderConfig(SpeechEncDecBaseConfig):
         default="[(1, 1), (2, 2), (1, 1), (2, 2)]",
         metadata={"help": "list of encoder convolution's out strides"},
     )
+    conv_apply_batchnorm: bool = field(
+        default=True,
+        metadata={
+            "help": "whether to apply BatchNorm after each convolution layer in the pre-encoder"
+        },
+    )
     transformer_context: Optional[str] = field(
         default=None,
         metadata={
diff --git a/espresso/models/transformer/speech_transformer_encoder.py b/espresso/models/transformer/speech_transformer_encoder.py
index cc025be84..4110f7cd1 100644
--- a/espresso/models/transformer/speech_transformer_encoder.py
+++ b/espresso/models/transformer/speech_transformer_encoder.py
@@ -105,7 +105,9 @@ def __init__(
         )

         if cfg.layernorm_embedding:
-            self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export)
+            self.layernorm_embedding = LayerNorm(
+                embed_dim, elementwise_affine=False, export=cfg.export
+            )  # sets elementwise_affine to False to stabilize training
         else:
             self.layernorm_embedding = None

diff --git a/espresso/models/transformer/speech_transformer_transducer_base.py b/espresso/models/transformer/speech_transformer_transducer_base.py
index f024fb968..627ce0ae0 100644
--- a/espresso/models/transformer/speech_transformer_transducer_base.py
+++ b/espresso/models/transformer/speech_transformer_transducer_base.py
@@ -9,6 +9,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from omegaconf import DictConfig
 from torch import Tensor

 import espresso.tools.utils as speech_utils
@@ -79,14 +80,12 @@ def __init__(self, cfg, encoder, decoder):
             self.fc_out = nn.Linear(
                 self.decoder.embed_tokens.embedding_dim,
                 self.decoder.embed_tokens.num_embeddings,
-                bias=False,
             )
             self.fc_out.weight = self.decoder.embed_tokens.weight
         else:
             self.fc_out = nn.Linear(
-                cfg.joint_dim, self.decoder.embed_tokens.num_embeddings, bias=False
+                cfg.joint_dim, self.decoder.embed_tokens.num_embeddings
             )
-            nn.init.normal_(self.fc_out.weight, mean=0, std=cfg.joint_dim**-0.5)
             self.fc_out = nn.utils.weight_norm(self.fc_out, name="weight")

         self.cfg = cfg
@@ -144,6 +143,7 @@ def build_model(cls, cfg, task):
                 kernel_sizes,
                 strides,
                 in_channels=task.feat_in_channels,
+                apply_batchnorm=cfg.encoder.conv_apply_batchnorm,
             )
             if out_channels is not None
             else None
@@ -310,3 +310,8 @@ def get_normalized_probs(
     ):
         """Get normalized probabilities (or log probs) from a net's output."""
         return self.get_normalized_probs_scriptable(net_output, log_probs, sample)
+
+    def prepare_for_inference_(self, cfg: DictConfig):
+        """Prepare model for inference."""
+        self.fc_out = nn.utils.remove_weight_norm(self.fc_out)
+        super().prepare_for_inference_(cfg)
diff --git a/espresso/modules/speech_convolutions.py b/espresso/modules/speech_convolutions.py
index e83e477a3..589de91dc 100644
--- a/espresso/modules/speech_convolutions.py
+++ b/espresso/modules/speech_convolutions.py
@@ -20,9 +20,19 @@


 class ConvBNReLU(nn.Module):
-    """Sequence of convolution-BatchNorm-ReLU layers."""
-
-    def __init__(self, out_channels, kernel_sizes, strides, in_channels=1):
+    """Sequence of convolution-[BatchNorm]-ReLU layers.
+
+    Args:
+        out_channels (list of int): the number of output channels of each conv layer
+        kernel_sizes (list of int or tuple): kernel size of each conv layer
+        strides (list of int or tuple): stride of each conv layer
+        in_channels (int, optional): the number of input channels (default: 1)
+        apply_batchnorm (bool, optional): if True, apply BatchNorm after each convolution layer (default: True)
+    """
+
+    def __init__(
+        self, out_channels, kernel_sizes, strides, in_channels=1, apply_batchnorm=True
+    ):
         super().__init__()
         if not has_packaging:
             raise ImportError("Please install packaging with: pip install packaging")
@@ -35,7 +45,7 @@ def __init__(self, out_channels, kernel_sizes, strides, in_channels=1):
         assert num_layers == len(kernel_sizes) and num_layers == len(strides)

         self.convolutions = nn.ModuleList()
-        self.batchnorms = nn.ModuleList()
+        self.batchnorms = nn.ModuleList() if apply_batchnorm else None
         for i in range(num_layers):
             self.convolutions.append(
                 Convolution2d(
@@ -45,7 +55,8 @@ def __init__(self, out_channels, kernel_sizes, strides, in_channels=1):
                     self.strides[i],
                 )
             )
-            self.batchnorms.append(nn.BatchNorm2d(out_channels[i]))
+            if apply_batchnorm:
+                self.batchnorms.append(nn.BatchNorm2d(out_channels[i]))

     def output_lengths(self, in_lengths: Union[torch.Tensor, int]):
         out_lengths = in_lengths
@@ -65,18 +76,22 @@ def output_lengths(self, in_lengths: Union[torch.Tensor, int]):
         return out_lengths

     def forward(self, src, src_lengths):
-        # B X T X C -> B X (input channel num) x T X (C / input channel num)
+        # B x T x C -> B x (input channel num) x T x (C / input channel num)
         x = src.view(
             src.size(0),
             src.size(1),
             self.in_channels,
             src.size(2) // self.in_channels,
         ).transpose(1, 2)
-        for conv, bn in zip(self.convolutions, self.batchnorms):
-            x = F.relu(bn(conv(x)))
-        # B X (output channel num) x T X C' -> B X T X (output channel num) X C'
+        if self.batchnorms is not None:
+            for conv, bn in zip(self.convolutions, self.batchnorms):
+                x = F.relu(bn(conv(x)))
+        else:
+            for conv in self.convolutions:
+                x = F.relu(conv(x))
+        # B x (output channel num) x T x C' -> B x T x (output channel num) x C'
         x = x.transpose(1, 2)
-        # B X T X (output channel num) X C' -> B X T X C
+        # B x T x (output channel num) x C' -> B x T x C
         x = x.contiguous().view(x.size(0), x.size(1), x.size(2) * x.size(3))

         x_lengths = self.output_lengths(src_lengths)
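
---

Note (not part of the patch): below is a minimal, self-contained sketch of the
new apply_batchnorm switch in isolation, written in plain PyTorch.
TinyConvBNReLU is a hypothetical stand-in for ConvBNReLU: it keeps the
conditional nn.ModuleList construction and the two forward branches from this
patch, but uses nn.Conv2d directly and omits Espresso's Convolution2d wrapper
and the output-length bookkeeping.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class TinyConvBNReLU(nn.Module):
        """Stripped-down stand-in for ConvBNReLU with optional BatchNorm."""

        def __init__(self, out_channels, kernel_sizes, strides,
                     in_channels=1, apply_batchnorm=True):
            super().__init__()
            self.convolutions = nn.ModuleList()
            # as in the patch: only build the BatchNorm list when enabled
            self.batchnorms = nn.ModuleList() if apply_batchnorm else None
            for i in range(len(out_channels)):
                self.convolutions.append(
                    nn.Conv2d(
                        in_channels if i == 0 else out_channels[i - 1],
                        out_channels[i],
                        kernel_sizes[i],
                        stride=strides[i],
                        padding=tuple(k // 2 for k in kernel_sizes[i]),
                    )
                )
                if apply_batchnorm:
                    self.batchnorms.append(nn.BatchNorm2d(out_channels[i]))

        def forward(self, x):
            # mirrors the two branches added to ConvBNReLU.forward
            if self.batchnorms is not None:
                for conv, bn in zip(self.convolutions, self.batchnorms):
                    x = F.relu(bn(conv(x)))
            else:
                for conv in self.convolutions:
                    x = F.relu(conv(x))
            return x

    # B x 1 x T x C input; time and feature dims are halved by each stride-2 layer
    m = TinyConvBNReLU([64, 64], [(3, 3), (3, 3)], [(2, 2), (2, 2)],
                       apply_batchnorm=False)
    print(m(torch.randn(4, 1, 100, 80)).shape)  # torch.Size([4, 64, 25, 20])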
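
The joiner change can be exercised the same way. This sketch follows the
lifecycle the patch sets up for fc_out: build it with nn.Linear's defaults
(bias included and the default initializer, now that bias=False and the
explicit normal_ init are gone), wrap it in weight normalization for training,
and unwrap it the way prepare_for_inference_ now does; joint_dim and
vocab_size are illustrative placeholders, not Espresso defaults.

    import torch
    import torch.nn as nn

    joint_dim, vocab_size = 512, 1000  # placeholders for cfg.joint_dim and
                                       # decoder.embed_tokens.num_embeddings

    # default bias=True and nn.Linear's default initializer
    fc_out = nn.Linear(joint_dim, vocab_size)
    # reparameterize weight as g * v / ||v|| for training
    fc_out = nn.utils.weight_norm(fc_out, name="weight")

    logits = fc_out(torch.randn(8, joint_dim))  # training-time forward

    # what prepare_for_inference_ does: fold g and v back into a plain weight
    fc_out = nn.utils.remove_weight_norm(fc_out)
    print(logits.shape, fc_out.weight.shape)
    # torch.Size([8, 1000]) torch.Size([1000, 512])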