This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

Merge pull request #182 from zhu-han/vgg_frontend
Add VGG frontend
danpovey authored May 2, 2021
2 parents c5ffa3f + d535568 commit b7f76b6
Showing 5 changed files with 177 additions and 28 deletions.
59 changes: 59 additions & 0 deletions egs/librispeech/asr/simple_v1/RESULTS.md
@@ -344,6 +344,65 @@ listed below.

# LibriSpeech MMI training results (Conformer)

## 2021-05-02

(Han Zhu): Results with VGG frontend.

Training log and tensorboard log can be found at <https://github.com/k2-fsa/snowfall/pull/182>.

Decoding results (WER) of the final model averaged over the last 5 epochs (i.e., epochs 5 to 9) and of each per-epoch model
without averaging are listed below.

```
# average over last 5 epochs (LM rescoring with whole lattice)
2021-05-02 00:36:42,886 INFO [common.py:381] [test-clean] %WER 5.55% [2916 / 52576, 548 ins, 172 del, 2196 sub ]
2021-05-02 00:47:15,544 INFO [common.py:381] [test-other] %WER 15.32% [8021 / 52343, 1270 ins, 501 del, 6250 sub ]
# average over last 5 epochs
2021-05-01 23:35:17,891 INFO [common.py:381] [test-clean] %WER 6.65% [3494 / 52576, 457 ins, 293 del, 2744 sub ]
2021-05-01 23:37:23,141 INFO [common.py:381] [test-other] %WER 17.68% [9252 / 52343, 1020 ins, 858 del, 7374 sub ]
# epoch 0
2021-05-02 01:09:52,745 INFO [common.py:381] [test-clean] %WER 21.68% [11396 / 52576, 1438 ins, 998 del, 8960 sub ]
2021-05-02 01:11:14,618 INFO [common.py:381] [test-other] %WER 45.48% [23808 / 52343, 2571 ins, 2370 del, 18867 sub ]
# epoch 1
2021-05-02 01:12:49,179 INFO [common.py:381] [test-clean] %WER 11.76% [6184 / 52576, 695 ins, 683 del, 4806 sub ]
2021-05-02 01:14:11,675 INFO [common.py:381] [test-other] %WER 29.74% [15569 / 52343, 1442 ins, 1937 del, 12190 sub ]
# epoch 2
2021-05-02 01:15:46,336 INFO [common.py:381] [test-clean] %WER 9.45% [4966 / 52576, 552 ins, 487 del, 3927 sub ]
2021-05-02 01:17:08,992 INFO [common.py:381] [test-other] %WER 24.86% [13013 / 52343, 1194 ins, 1685 del, 10134 sub ]
# epoch 3
2021-05-02 01:18:43,584 INFO [common.py:381] [test-clean] %WER 9.49% [4987 / 52576, 549 ins, 686 del, 3752 sub ]
2021-05-02 01:20:08,417 INFO [common.py:381] [test-other] %WER 25.26% [13220 / 52343, 1029 ins, 2292 del, 9899 sub ]
# epoch 4
2021-05-02 01:21:43,498 INFO [common.py:381] [test-clean] %WER 8.00% [4207 / 52576, 492 ins, 382 del, 3333 sub ]
2021-05-02 01:23:06,132 INFO [common.py:381] [test-other] %WER 20.88% [10929 / 52343, 1056 ins, 1188 del, 8685 sub ]
# epoch 5
2021-05-02 01:24:39,382 INFO [common.py:381] [test-clean] %WER 7.89% [4148 / 52576, 500 ins, 347 del, 3301 sub ]
2021-05-02 01:26:02,202 INFO [common.py:381] [test-other] %WER 21.10% [11043 / 52343, 1233 ins, 1105 del, 8705 sub ]
# epoch 6
2021-05-02 01:27:35,616 INFO [common.py:381] [test-clean] %WER 7.72% [4058 / 52576, 471 ins, 380 del, 3207 sub ]
2021-05-02 01:28:58,678 INFO [common.py:381] [test-other] %WER 20.40% [10677 / 52343, 1106 ins, 1174 del, 8397 sub ]
# epoch 7
2021-05-02 01:30:32,897 INFO [common.py:381] [test-clean] %WER 7.40% [3893 / 52576, 470 ins, 349 del, 3074 sub ]
2021-05-02 01:31:54,306 INFO [common.py:381] [test-other] %WER 19.61% [10264 / 52343, 1037 ins, 1047 del, 8180 sub ]
# epoch 8
2021-05-02 01:33:28,578 INFO [common.py:381] [test-clean] %WER 7.40% [3890 / 52576, 489 ins, 329 del, 3072 sub ]
2021-05-02 01:34:52,473 INFO [common.py:381] [test-other] %WER 19.70% [10312 / 52343, 1157 ins, 1009 del, 8146 sub ]
# epoch 9
2021-05-02 01:36:30,299 INFO [common.py:381] [test-clean] %WER 7.32% [3848 / 52576, 525 ins, 321 del, 3002 sub ]
2021-05-02 01:37:52,445 INFO [common.py:381] [test-other] %WER 19.93% [10430 / 52343, 1251 ins, 956 del, 8223 sub ]
```
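
For reference, each %WER figure above is the error count (insertions + deletions + substitutions) divided by the number of reference words, as the bracketed counts show; a quick arithmetic check of the first rescored test-clean line:

```
# %WER = 100 * (ins + del + sub) / reference words; first test-clean line above:
ins, dels, subs, ref_words = 548, 172, 2196, 52576
print(f"%WER {100.0 * (ins + dels + subs) / ref_words:.2f}%")  # -> %WER 5.55%
```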

## 2021-03-26

Results when adding SpecAugment with the schedule proposed in the original paper that introduces it;
9 changes: 5 additions & 4 deletions egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py
@@ -241,7 +241,7 @@ def main():

output_beam_size = args.output_beam_size

-exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan-sa')
+exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan-sa-vgg')
setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')

logging.info(f'output_beam_size: {output_beam_size}')
@@ -274,16 +274,17 @@ def main():
d_model=args.attention_dim,
num_classes=len(phone_ids) + 1, # +1 for the blank symbol
subsampling_factor=4,
-num_decoder_layers=num_decoder_layers)
+num_decoder_layers=num_decoder_layers,
+vgg_frontend=True)
else:
model = Conformer(
num_features=80,
nhead=args.nhead,
d_model=args.attention_dim,
num_classes=len(phone_ids) + 1, # +1 for the blank symbol
subsampling_factor=4,
-num_decoder_layers=num_decoder_layers)
-
+num_decoder_layers=num_decoder_layers,
+vgg_frontend=True)
model.P_scores = torch.nn.Parameter(P.scores.clone(), requires_grad=False)

if avg == 1:
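
The "average over last 5 epochs" results above come from averaging model parameters across checkpoints before decoding (the `if avg == 1:` branch above loads a single checkpoint instead). The sketch below illustrates the idea only; the function name and checkpoint layout are assumptions, not snowfall's actual API:

```
import torch

def average_state_dicts(checkpoint_paths):
    # Element-wise average of parameters across checkpoints (illustrative
    # sketch; the helper name and checkpoint format are assumptions).
    avg = None
    for path in checkpoint_paths:
        state = torch.load(path, map_location='cpu')
        if avg is None:
            avg = {k: v.clone().float() for k, v in state.items()}
        else:
            for k in avg:
                avg[k] += state[k].float()
    return {k: v / len(checkpoint_paths) for k, v in avg.items()}

# e.g. averaging epochs 5-9, as in the results above:
# model.load_state_dict(average_state_dicts([f'epoch-{i}.pt' for i in range(5, 10)]))
```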
8 changes: 5 additions & 3 deletions egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py
@@ -436,7 +436,7 @@ def run(rank, world_size, args):
fix_random_seed(42)
setup_dist(rank, world_size, args.master_port)

-exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan-sa')
+exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan-sa-vgg')
setup_logger(f'{exp_dir}/log/log-train-{rank}')
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')
@@ -482,15 +482,17 @@ def run(rank, world_size, args):
d_model=args.attention_dim,
num_classes=len(phone_ids) + 1, # +1 for the blank symbol
subsampling_factor=4,
-num_decoder_layers=num_decoder_layers)
+num_decoder_layers=num_decoder_layers,
+vgg_frontend=True)
else:
model = Conformer(
num_features=80,
nhead=args.nhead,
d_model=args.attention_dim,
num_classes=len(phone_ids) + 1, # +1 for the blank symbol
subsampling_factor=4,
-num_decoder_layers=num_decoder_layers)
+num_decoder_layers=num_decoder_layers,
+vgg_frontend=True)

model.P_scores = nn.Parameter(P.scores.clone(), requires_grad=True)

23 changes: 12 additions & 11 deletions snowfall/models/conformer.py
@@ -29,18 +29,19 @@ class Conformer(Transformer):
dropout (float): dropout rate
cnn_module_kernel (int): Kernel size of convolution module
normalize_before (bool): whether to use layer_norm before the first block.
vgg_frontend (bool): whether to use vgg frontend.
"""

def __init__(self, num_features: int, num_classes: int, subsampling_factor: int = 4,
d_model: int = 256, nhead: int = 4, dim_feedforward: int = 2048,
num_encoder_layers: int = 12, num_decoder_layers: int = 6,
dropout: float = 0.1, cnn_module_kernel: int = 31,
-normalize_before: bool = True) -> None:
+normalize_before: bool = True, vgg_frontend: bool = False) -> None:
super(Conformer, self).__init__(num_features=num_features, num_classes=num_classes, subsampling_factor=subsampling_factor,
d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward,
num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers,
-dropout=dropout, normalize_before=normalize_before)
+dropout=dropout, normalize_before=normalize_before, vgg_frontend=vgg_frontend)

self.encoder_pos = RelPositionalEncoding(d_model, dropout)

encoder_layer = ConformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, cnn_module_kernel, normalize_before)
@@ -112,7 +113,7 @@ def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropou
self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module

self.ff_scale = 0.5

self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
@@ -324,7 +325,7 @@ def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.) -> None:
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

@@ -575,14 +576,14 @@ def multi_head_attention_forward(self, query: Tensor,
assert key_padding_mask.size(1) == src_len, "{} == {}".format(key_padding_mask.size(1), src_len)


q = q.transpose(0, 1)  # (batch, time1, head, d_k)

n_batch_pos = pos_emb.size(0)
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, num_heads, head_dim)
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)

q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) # (batch, head, time1, d_k)

q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) # (batch, head, time1, d_k)

# compute attention score
@@ -700,12 +701,12 @@ def forward(self, x: Tensor) -> Tensor:

x = self.pointwise_conv2(x) # (batch, channel, time)

return x.permute(2, 0, 1)


class Swish(torch.nn.Module):
"""Construct a Swish object."""

def forward(self, x: Tensor) -> Tensor:
"""Return Swish activation function."""
return x * torch.sigmoid(x)
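
As an aside, the Swish activation defined at the end of conformer.py computes the same function PyTorch ships as SiLU (available since PyTorch 1.7), so the two are interchangeable here; a quick equivalence check:

```
import torch

x = torch.randn(4)
swish = x * torch.sigmoid(x)        # Swish as defined above
silu = torch.nn.functional.silu(x)  # built-in equivalent (PyTorch >= 1.7)
assert torch.allclose(swish, silu)
```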
106 changes: 96 additions & 10 deletions snowfall/models/transformer.py
@@ -26,20 +26,23 @@ class Transformer(AcousticModel):
num_decoder_layers (int): number of decoder layers
dropout (float): dropout rate
normalize_before (bool): whether to use layer_norm before the first block.
vgg_frontend (bool): whether to use vgg frontend.
"""

def __init__(self, num_features: int, num_classes: int, subsampling_factor: int = 4,
d_model: int = 256, nhead: int = 4, dim_feedforward: int = 2048,
num_encoder_layers: int = 12, num_decoder_layers: int = 6,
-dropout: float = 0.1, normalize_before: bool = True) -> None:
+dropout: float = 0.1, normalize_before: bool = True,
+vgg_frontend: bool = False) -> None:
super().__init__()
self.num_features = num_features
self.num_classes = num_classes
self.subsampling_factor = subsampling_factor
if subsampling_factor != 4:
raise NotImplementedError("Support only 'subsampling_factor=4'.")

-self.encoder_embed = Conv2dSubsampling(num_features, d_model)
+self.encoder_embed = (VggSubsampling(num_features, d_model) if vgg_frontend else
+Conv2dSubsampling(num_features, d_model))
self.encoder_pos = PositionalEncoding(d_model, dropout)

encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
@@ -132,7 +135,7 @@ def decoder_forward(self, x: Tensor, encoder_mask: Tensor, supervision: Dict, gr
x: Tensor of dimension (input_length, batch_size, d_model).
encoder_mask: Mask tensor of dimension (batch_size, input_length)
supervision: Supervision in lhotse format, get from batch['supervisions']
graph_compiler: use graph_compiler.L_inv (Its labels are words, while its aux_labels are phones)
, graph_compiler.words and graph_compiler.oov
Returns:
@@ -166,7 +169,7 @@ def decoder_forward(self, x: Tensor, encoder_mask: Tensor, supervision: Dict, gr

class TransformerEncoderLayer(nn.Module):
"""
Modified from torch.nn.TransformerEncoderLayer. Add support of normalize_before,
i.e., use layer_norm before the first block.
Args:
@@ -243,9 +246,9 @@ def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,

class TransformerDecoderLayer(nn.Module):
"""
Modified from torch.nn.TransformerDecoderLayer. Add support of normalize_before,
i.e., use layer_norm before the first block.
Args:
d_model: the number of expected features in the input (required).
nhead: the number of heads in the multiheadattention models (required).
@@ -357,7 +360,7 @@ class Conv2dSubsampling(nn.Module):
"""

def __init__(self, idim: int, odim: int) -> None:
"""Construct an Conv2dSubsampling object."""
"""Construct a Conv2dSubsampling object."""
super(Conv2dSubsampling, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2),
@@ -385,6 +388,73 @@ def forward(self, x: Tensor) -> Tensor:
return x


class VggSubsampling(nn.Module):
"""Trying to follow the setup described here https://arxiv.org/pdf/1910.09799.pdf
This paper is not 100% explicit so I am guessing to some extent,
and trying to compare with other VGG implementations.
Args:
idim: Input dimension.
odim: Output dimension.
"""

def __init__(self, idim: int, odim: int) -> None:
"""Construct a VggSubsampling object. This uses 2 VGG blocks with 2
Conv2d layers each, subsampling its input by a factor of 4 in the
time dimensions.
Args:
idim: Number of features at input, e.g. 40 or 80 for MFCC
(will be treated as the image height).
odim: Output dimension (number of features), e.g. 256
"""
super(VggSubsampling, self).__init__()

cur_channels = 1
layers = []
block_dims = [32,64]

# The decision to use padding=1 for the 1st convolution, then padding=0
# for the 2nd and for the max-pooling, and ceil_mode=True, was driven by
# a back-compatibility concern so that the number of frames at the
# output would be equal to:
# (((T-1)//2)-1)//2.
# We can consider changing this by using padding=1 on the 2nd convolution,
# so the num-frames at the output would be T//4.
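# For example, with T = 100 input frames: the first block gives
# 100 -> 100 (conv, padding=1) -> 98 (conv, padding=0) -> 49 (ceil-mode pool),
# and the second gives 49 -> 49 -> 47 -> 24, i.e. (((100-1)//2)-1)//2 = 24.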
for block_dim in block_dims:
layers.append(torch.nn.Conv2d(in_channels=cur_channels, out_channels=block_dim,
kernel_size=3, padding=1, stride=1))
layers.append(torch.nn.ReLU())
layers.append(torch.nn.Conv2d(in_channels=block_dim, out_channels=block_dim,
kernel_size=3, padding=0, stride=1))
layers.append(torch.nn.MaxPool2d(kernel_size=2, stride=2,
padding=0, ceil_mode=True))
cur_channels = block_dim

self.layers = nn.Sequential(*layers)

self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim)


def forward(self, x: Tensor) -> Tensor:
"""Subsample x.
Args:
x: Input tensor of dimension (batch_size, input_length, num_features). (#batch, time, idim).
Returns:
torch.Tensor: Subsampled tensor of dimension (batch_size, input_length', d_model).
where input_length' == (((input_length - 1) // 2) - 1) // 2
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.layers(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
return x


class PositionalEncoding(nn.Module):
"""
Positional encoding.
@@ -443,7 +513,7 @@ class Noam(object):
"""
Implements Noam optimizer. Proposed in "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py
Args:
params (iterable): iterable of parameters to optimize or dicts defining parameter groups
model_size: attention dimension of the transformer model
@@ -549,7 +619,7 @@ def forward(self, x: Tensor, target: Tensor) -> Tensor:
x: prediction of dimension (batch_size, input_length, number_of_classes).
target: target masked with self.padding_id of dimension (batch_size, input_length).
Returns:
torch.Tensor: scalar float value
"""
assert x.size(2) == self.size
@@ -642,7 +712,7 @@ def generate_square_subsequent_mask(sz: int) -> Tensor:
"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
Unmasked positions are filled with float(0.0).
Args:
sz: mask size
Returns:
@@ -744,3 +814,19 @@ def get_hierarchical_targets(ys: List[List[int]], lexicon: k2.Fsa) -> List[Tenso
ys = [torch.tensor(y) for y in ys]

return ys



def test_transformer():
t = Transformer(40, 1281)
T = 200
f = torch.rand(31, 40, T)
g, _, _ = t(f)
assert g.shape == (31, 1281, (((T-1)//2)-1)//2)

def main():
test_transformer()


if __name__ == '__main__':
main()
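
In the spirit of test_transformer above, here is a small shape check for the two frontends; a sketch assuming torch and the snowfall package are importable, and using the output-length formula from the comments above:

```
import torch
from snowfall.models.transformer import Conv2dSubsampling, VggSubsampling

def test_frontends():
    T, idim, odim = 100, 80, 256
    x = torch.randn(8, T, idim)             # (batch, time, features)
    expected_t = (((T - 1) // 2) - 1) // 2  # = 24 for T = 100
    # Both frontends subsample time by a factor of ~4 and project to odim.
    for frontend in (Conv2dSubsampling(idim, odim), VggSubsampling(idim, odim)):
        y = frontend(x)
        assert y.shape == (8, expected_t, odim), y.shape

test_frontends()
```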
