Skip to content

Commit

Permalink
makes BatchNorm configurable in preencoder; turns off elementwise_aff…
Browse files Browse the repository at this point in the history
…ine for layernorm_embedding in transformer encoder to stabilize training; adds the bias term to joiner's final fc_out and uses its default initializer
  • Loading branch information
freewym committed Nov 10, 2022
1 parent 321982b commit 5d32606
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 14 deletions.
1 change: 1 addition & 0 deletions espresso/models/speech_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
kernel_sizes,
strides,
in_channels=task.feat_in_channels,
apply_batchnorm=True,
)
if out_channels is not None
else None
Expand Down
6 changes: 6 additions & 0 deletions espresso/models/transformer/speech_transformer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ class SpeechEncoderConfig(SpeechEncDecBaseConfig):
default="[(1, 1), (2, 2), (1, 1), (2, 2)]",
metadata={"help": "list of encoder convolution's out strides"},
)
conv_apply_batchnorm: bool = field(
default=True,
metadata={
"help": "whether to apply BatchNorm after each convolution layer in pre-encoder"
},
)
transformer_context: Optional[str] = field(
default=None,
metadata={
Expand Down
4 changes: 3 additions & 1 deletion espresso/models/transformer/speech_transformer_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ def __init__(
)

if cfg.layernorm_embedding:
self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export)
self.layernorm_embedding = LayerNorm(
embed_dim, elementwise_affine=False, export=cfg.export
) # sets elementwise_affine to False to stabilize training
else:
self.layernorm_embedding = None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import DictConfig
from torch import Tensor

import espresso.tools.utils as speech_utils
Expand Down Expand Up @@ -79,14 +80,12 @@ def __init__(self, cfg, encoder, decoder):
self.fc_out = nn.Linear(
self.decoder.embed_tokens.embedding_dim,
self.decoder.embed_tokens.num_embeddings,
bias=False,
)
self.fc_out.weight = self.decoder.embed_tokens.weight
else:
self.fc_out = nn.Linear(
cfg.joint_dim, self.decoder.embed_tokens.num_embeddings, bias=False
cfg.joint_dim, self.decoder.embed_tokens.num_embeddings
)
nn.init.normal_(self.fc_out.weight, mean=0, std=cfg.joint_dim**-0.5)
self.fc_out = nn.utils.weight_norm(self.fc_out, name="weight")

self.cfg = cfg
Expand Down Expand Up @@ -144,6 +143,7 @@ def build_model(cls, cfg, task):
kernel_sizes,
strides,
in_channels=task.feat_in_channels,
apply_batchnorm=cfg.encoder.conv_apply_batchnorm,
)
if out_channels is not None
else None
Expand Down Expand Up @@ -310,3 +310,8 @@ def get_normalized_probs(
):
"""Get normalized probabilities (or log probs) from a net's output."""
return self.get_normalized_probs_scriptable(net_output, log_probs, sample)

def prepare_for_inference_(self, cfg: DictConfig):
"""Prepare model for inference."""
self.fc_out = nn.utils.remove_weight_norm(self.fc_out)
super().prepare_for_inference_(cfg)
35 changes: 25 additions & 10 deletions espresso/modules/speech_convolutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,19 @@


class ConvBNReLU(nn.Module):
"""Sequence of convolution-BatchNorm-ReLU layers."""

def __init__(self, out_channels, kernel_sizes, strides, in_channels=1):
"""Sequence of convolution-[BatchNorm]-ReLU layers.
Args:
out_channels (int): the number of output channels of conv layer
kernel_sizes (int or tuple): kernel sizes
strides (int or tuple): strides
in_channels (int, optional): the number of input channels (default: 1)
apply_batchnorm (bool, optional): if True apply BatchNorm after each convolution layer (default: True)
"""

def __init__(
self, out_channels, kernel_sizes, strides, in_channels=1, apply_batchnorm=True
):
super().__init__()
if not has_packaging:
raise ImportError("Please install packaging with: pip install packaging")
Expand All @@ -35,7 +45,7 @@ def __init__(self, out_channels, kernel_sizes, strides, in_channels=1):
assert num_layers == len(kernel_sizes) and num_layers == len(strides)

self.convolutions = nn.ModuleList()
self.batchnorms = nn.ModuleList()
self.batchnorms = nn.ModuleList() if apply_batchnorm else None
for i in range(num_layers):
self.convolutions.append(
Convolution2d(
Expand All @@ -45,7 +55,8 @@ def __init__(self, out_channels, kernel_sizes, strides, in_channels=1):
self.strides[i],
)
)
self.batchnorms.append(nn.BatchNorm2d(out_channels[i]))
if apply_batchnorm:
self.batchnorms.append(nn.BatchNorm2d(out_channels[i]))

def output_lengths(self, in_lengths: Union[torch.Tensor, int]):
out_lengths = in_lengths
Expand All @@ -65,18 +76,22 @@ def output_lengths(self, in_lengths: Union[torch.Tensor, int]):
return out_lengths

def forward(self, src, src_lengths):
# B X T X C -> B X (input channel num) x T X (C / input channel num)
# B x T x C -> B x (input channel num) x T x (C / input channel num)
x = src.view(
src.size(0),
src.size(1),
self.in_channels,
src.size(2) // self.in_channels,
).transpose(1, 2)
for conv, bn in zip(self.convolutions, self.batchnorms):
x = F.relu(bn(conv(x)))
# B X (output channel num) x T X C' -> B X T X (output channel num) X C'
if self.batchnorms is not None:
for conv, bn in zip(self.convolutions, self.batchnorms):
x = F.relu(bn(conv(x)))
else:
for conv in self.convolutions:
x = F.relu(conv(x))
# B x (output channel num) x T x C' -> B x T x (output channel num) x C'
x = x.transpose(1, 2)
# B X T X (output channel num) X C' -> B X T X C
# B x T x (output channel num) x C' -> B x T x C
x = x.contiguous().view(x.size(0), x.size(1), x.size(2) * x.size(3))

x_lengths = self.output_lengths(src_lengths)
Expand Down

0 comments on commit 5d32606

Please sign in to comment.