Commit
in coarse transformer, make sure that coarse tokens attending to semantic tokens (cross attention) do not use relative positions
lucidrains committed Mar 20, 2023
1 parent dd4784a commit 0491eaa
Showing 2 changed files with 35 additions and 5 deletions.
38 changes: 34 additions & 4 deletions audiolm_pytorch/audiolm_pytorch.py
@@ -198,7 +198,12 @@ def __init__(
 
         self.net.append(nn.Linear(dim, heads))
 
-    def forward(self, n, device = torch.device('cpu')):
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def forward(self, n):
+        device = self.device
         pos = torch.arange(n, device = device)
         rel_pos = (rearrange(pos, 'i -> i 1') - rearrange(pos, 'j -> 1 j'))
         rel_pos += (n - 1)
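
Aside: a minimal, hypothetical sketch of the pattern this hunk adopts. The module infers its own device from its registered parameters, so callers no longer pass a device into forward (class and argument names below are made up, not the library's):

import torch
from torch import nn

class RelPosBiasSketch(nn.Module):          # hypothetical stand-in, not the library's class
    def __init__(self, dim = 32, heads = 8):
        super().__init__()
        self.net = nn.Linear(dim, heads)    # any registered parameter pins down the device

    @property
    def device(self):
        # wherever the parameters were moved (cpu, cuda, ...) is where new tensors should go
        return next(self.parameters()).device

    def forward(self, n):
        # no device argument needed anymore; positions follow the module's parameters
        return torch.arange(n, device = self.device)
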
@@ -432,7 +437,7 @@ def forward(
         if exists(attn_bias):
             rel_pos_bias = attn_bias
         else:
-            rel_pos_bias = maybe(self.rel_pos_bias)(n, device = device)
+            rel_pos_bias = maybe(self.rel_pos_bias)(n)
 
         self_attn_kwargs = dict()
         if self.cond_as_self_attn_prefix:
@@ -623,6 +628,8 @@ def __init__(
         text_dim = default(cond_dim, get_encoded_dim(t5_name))
         self.proj_text_embed = nn.Linear(text_dim, dim, bias = False) if text_dim != dim else nn.Identity()
 
+        self.cross_attn_bias = nn.Parameter(torch.zeros(heads, 1, 1))
+
         self.transformer = Transformer(
            dim = dim,
            depth = depth,
@@ -677,6 +684,7 @@ def forward(
         return_only_coarse_logits = False
     ):
         b, device = semantic_token_ids.shape[0], semantic_token_ids.device
+        arange = partial(torch.arange, device = device)
 
         has_text = exists(text) or exists(text_embeds)
         assert not (self.has_condition ^ has_text)
@@ -699,7 +707,7 @@ def forward(
 
         coarse_token_ids, semantic_token_ids = map(lambda t: rearrange(t, 'b ... -> b (...)'), (coarse_token_ids, semantic_token_ids))
 
-        offsets = self.codebook_size * torch.arange(self.num_coarse_quantizers, device = device)
+        offsets = self.codebook_size * arange(self.num_coarse_quantizers)
         offsets = repeat(offsets, 'q -> 1 (n q)', n = ceil_div(coarse_token_ids.shape[-1], self.num_coarse_quantizers))
         offsets = offsets[:, :coarse_token_ids.shape[-1]]
         coarse_token_ids = coarse_token_ids + offsets
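
For reference, a small standalone sketch (sizes are made up) of what these offsets do: the coarse quantizers share one embedding table, so token ids from quantizer q are shifted by q * codebook_size into a disjoint index range; `arange` in the hunk above is simply `torch.arange` with the device baked in via `functools.partial`.

import torch
from math import ceil
from einops import repeat

# hypothetical sizes, purely for illustration
codebook_size, num_coarse_quantizers = 1024, 3
coarse_token_ids = torch.randint(0, codebook_size, (1, 7))   # quantizer outputs interleaved along the sequence

offsets = codebook_size * torch.arange(num_coarse_quantizers)    # tensor([0, 1024, 2048])
offsets = repeat(offsets, 'q -> 1 (n q)', n = ceil(coarse_token_ids.shape[-1] / num_coarse_quantizers))
offsets = offsets[:, :coarse_token_ids.shape[-1]]                # trim to the actual sequence length
coarse_token_ids = coarse_token_ids + offsets   # quantizer q now indexes its own slice of one shared embedding table
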
@@ -723,7 +731,29 @@ def forward(
             coarse_tokens
         ), dim = 1)
 
-        tokens = self.transformer(tokens, context = text_embeds, self_attn_mask = self_attn_mask, context_mask = text_mask)
+        # engineer the attention bias so that cross attention is not dominated by relative positions
+
+        seq_len = tokens.shape[-2]
+        attn_bias = self.transformer.rel_pos_bias(seq_len)
+
+        is_semantic = arange(seq_len) < (semantic_seq_len + 1) # semantic seq len + start token
+        is_cross_attn = rearrange(is_semantic, 'i -> i 1') ^ rearrange(is_semantic, 'j -> 1 j')
+
+        attn_bias = torch.where(
+            is_cross_attn,
+            self.cross_attn_bias,
+            attn_bias
+        )
+
+        # attend
+
+        tokens = self.transformer(
+            tokens,
+            context = text_embeds,
+            attn_bias = attn_bias,
+            self_attn_mask = self_attn_mask,
+            context_mask = text_mask
+        )
 
         pred_semantic_tokens, pred_coarse_tokens = tokens[:, :semantic_seq_len], tokens[:, (semantic_seq_len + 1):]
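
A small standalone sketch (shapes are made up) of the bias construction above: positions are marked as semantic or coarse, an XOR picks out exactly the query/key pairs that cross between the two segments, and torch.where swaps the relative position bias for the single learned per-head scalar at those entries, leaving within-segment attention untouched.

import torch
from einops import rearrange

# hypothetical shapes, purely for illustration
heads, semantic_seq_len, coarse_seq_len = 8, 4, 5
seq_len = semantic_seq_len + 1 + coarse_seq_len              # semantic tokens + start token + coarse tokens

rel_pos_bias = torch.randn(heads, seq_len, seq_len)          # stand-in for the transformer's relative position bias
cross_attn_bias = torch.zeros(heads, 1, 1)                   # the new learned parameter, one scalar per head

is_semantic = torch.arange(seq_len) < (semantic_seq_len + 1) # True for semantic positions (and the start token)
# True exactly where query and key fall on different sides of the semantic / coarse boundary
is_cross_attn = rearrange(is_semantic, 'i -> i 1') ^ rearrange(is_semantic, 'j -> 1 j')

# keep relative positions within each segment, use a flat learned bias across segments
attn_bias = torch.where(is_cross_attn, cross_attn_bias, rel_pos_bias)
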

2 changes: 1 addition & 1 deletion audiolm_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.23.7'
+__version__ = '0.24.0'
