From 0491eaafb9c4cf3870443d0f57465b3d89e5ecfa Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 20 Mar 2023 10:05:23 -0700
Subject: [PATCH] in coarse transformer, make sure that coarse tokens
 attending to semantic tokens (cross attention) do not use relative positions

---
 audiolm_pytorch/audiolm_pytorch.py | 38 ++++++++++++++++++++++++++----
 audiolm_pytorch/version.py         |  2 +-
 2 files changed, 35 insertions(+), 5 deletions(-)

diff --git a/audiolm_pytorch/audiolm_pytorch.py b/audiolm_pytorch/audiolm_pytorch.py
index e1c7521..77a5f11 100644
--- a/audiolm_pytorch/audiolm_pytorch.py
+++ b/audiolm_pytorch/audiolm_pytorch.py
@@ -198,7 +198,12 @@ def __init__(
 
         self.net.append(nn.Linear(dim, heads))
 
-    def forward(self, n, device = torch.device('cpu')):
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def forward(self, n):
+        device = self.device
         pos = torch.arange(n, device = device)
         rel_pos = (rearrange(pos, 'i -> i 1') - rearrange(pos, 'j -> 1 j'))
         rel_pos += (n - 1)
@@ -432,7 +437,7 @@ def forward(
         if exists(attn_bias):
             rel_pos_bias = attn_bias
         else:
-            rel_pos_bias = maybe(self.rel_pos_bias)(n, device = device)
+            rel_pos_bias = maybe(self.rel_pos_bias)(n)
 
         self_attn_kwargs = dict()
         if self.cond_as_self_attn_prefix:
@@ -623,6 +628,8 @@ def __init__(
         text_dim = default(cond_dim, get_encoded_dim(t5_name))
         self.proj_text_embed = nn.Linear(text_dim, dim, bias = False) if text_dim != dim else nn.Identity()
 
+        self.cross_attn_bias = nn.Parameter(torch.zeros(heads, 1, 1))
+
         self.transformer = Transformer(
             dim = dim,
             depth = depth,
@@ -677,6 +684,7 @@ def forward(
         return_only_coarse_logits = False
     ):
         b, device = semantic_token_ids.shape[0], semantic_token_ids.device
+        arange = partial(torch.arange, device = device)
 
         has_text = exists(text) or exists(text_embeds)
         assert not (self.has_condition ^ has_text)
@@ -699,7 +707,7 @@ def forward(
 
         coarse_token_ids, semantic_token_ids = map(lambda t: rearrange(t, 'b ... -> b (...)'), (coarse_token_ids, semantic_token_ids))
 
-        offsets = self.codebook_size * torch.arange(self.num_coarse_quantizers, device = device)
+        offsets = self.codebook_size * arange(self.num_coarse_quantizers)
         offsets = repeat(offsets, 'q -> 1 (n q)', n = ceil_div(coarse_token_ids.shape[-1], self.num_coarse_quantizers))
         offsets = offsets[:, :coarse_token_ids.shape[-1]]
         coarse_token_ids = coarse_token_ids + offsets
@@ -723,7 +731,29 @@ def forward(
             coarse_tokens
         ), dim = 1)
 
-        tokens = self.transformer(tokens, context = text_embeds, self_attn_mask = self_attn_mask, context_mask = text_mask)
+        # engineer the attention bias so that cross attention is not dominated by relative positions
+
+        seq_len = tokens.shape[-2]
+        attn_bias = self.transformer.rel_pos_bias(seq_len)
+
+        is_semantic = arange(seq_len) < (semantic_seq_len + 1) # semantic seq len + start token
+        is_cross_attn = rearrange(is_semantic, 'i -> i 1') ^ rearrange(is_semantic, 'j -> 1 j')
+
+        attn_bias = torch.where(
+            is_cross_attn,
+            self.cross_attn_bias,
+            attn_bias
+        )
+
+        # attend
+
+        tokens = self.transformer(
+            tokens,
+            context = text_embeds,
+            attn_bias = attn_bias,
+            self_attn_mask = self_attn_mask,
+            context_mask = text_mask
+        )
 
         pred_semantic_tokens, pred_coarse_tokens = tokens[:, :semantic_seq_len], tokens[:, (semantic_seq_len + 1):]
 
diff --git a/audiolm_pytorch/version.py b/audiolm_pytorch/version.py
index caf9513..f8ab8c2 100644
--- a/audiolm_pytorch/version.py
+++ b/audiolm_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.23.7'
+__version__ = '0.24.0'
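
The hunk at @@ -723,7 +731,29 @@ is the heart of the change: the coarse
transformer self-attends over the concatenation [semantic tokens | start
token | coarse tokens], so positions where query and key fall on opposite
sides of that boundary ("cross attention" in the commit's sense) get a single
learned per-head scalar instead of the relative position bias, whose
distance-based values carry no meaning across the two token streams. A minimal
standalone sketch of the torch.where trick follows (illustrative only: heads,
seq_len and semantic_seq_len are arbitrary stand-in sizes, and rel_pos_bias
stands in for the transformer's relative position bias output):

    import torch
    from einops import rearrange

    heads, seq_len, semantic_seq_len = 8, 10, 4

    # stand-in for the (heads, seq_len, seq_len) relative position bias
    rel_pos_bias = torch.randn(heads, seq_len, seq_len)

    # learned per-head scalar, zero-initialized as in the patch
    cross_attn_bias = torch.zeros(heads, 1, 1)

    # positions below (semantic_seq_len + 1) are the semantic tokens plus the start token
    is_semantic = torch.arange(seq_len) < (semantic_seq_len + 1)

    # XOR is True exactly where query and key sit on opposite sides of the boundary
    is_cross_attn = rearrange(is_semantic, 'i -> i 1') ^ rearrange(is_semantic, 'j -> 1 j')

    # flat learned bias at cross positions, relative position bias everywhere else
    attn_bias = torch.where(is_cross_attn, cross_attn_bias, rel_pos_bias)

    assert attn_bias.shape == (heads, seq_len, seq_len)

Note that semantic-to-semantic and coarse-to-coarse pairs keep their relative
position bias untouched; only the mixed pairs are flattened to the learned
scalar.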
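Relatedly, the hunk at @@ -198,7 +198,12 @@ drops the device argument from the
bias module's forward in favor of the usual PyTorch pattern of reading the
device off the module's own parameters, which is what lets the coarse
transformer call rel_pos_bias(seq_len) without threading a device through. A
small sketch of that pattern (the RelPosBias module below is a hypothetical
stand-in, not the library's class):

    import torch
    from torch import nn

    class RelPosBias(nn.Module):
        def __init__(self, dim, heads):
            super().__init__()
            self.net = nn.Linear(dim, heads)

        @property
        def device(self):
            # all parameters share the module's device, so any one of them will do
            return next(self.parameters()).device

        def forward(self, n):
            # tensors created inside forward now follow the module automatically
            return torch.arange(n, device = self.device)

    bias = RelPosBias(512, 8)
    print(bias(16).device)  # cpu, or cuda:0 after bias.cuda()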