Commit

Remove float32 casts from MHA
Signed-off-by: Chris Granger <[email protected]>
CGranger-sorenson authored Aug 23, 2024
1 parent e772dbf commit 3a8cefe
Showing 1 changed file with 163 additions and 178 deletions.
nemo/collections/asr/parts/submodules/multi_head_attention.py (163 additions, 178 deletions)
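For context: the removed pattern cast the attention inputs to float32 and wrapped the computation in a NeMo helper (`avoid_float16_autocast_context`) that, judging by its name and the explicit casts, forced the attention math out of reduced-precision autocast. After this commit the attention simply follows the ambient autocast dtype. A minimal stand-alone sketch of that difference using only public `torch.autocast` APIs (CPU/bfloat16 chosen so it runs anywhere; this is an approximation, not NeMo's helper):

```python
import torch

# Hypothetical shapes; q/k stand in for the projected query/key tensors.
q = torch.randn(2, 4, 16, 8)
k = torch.randn(2, 4, 16, 8)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    # Behaviour after this commit: the matmul runs in the autocast dtype.
    scores_amp = torch.matmul(q, k.transpose(-2, -1))
    print(scores_amp.dtype)  # torch.bfloat16

    # Rough approximation of the removed path: locally disable autocast
    # and compute in float32.
    with torch.autocast(device_type="cpu", enabled=False):
        scores_fp32 = torch.matmul(q.float(), k.float().transpose(-2, -1))
        print(scores_fp32.dtype)  # torch.float32
```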
@@ -133,14 +133,9 @@ def forward(self, query, key, value, mask, pos_emb=None, cache=None):
         """
         key, value, query, cache = self.update_cache(key=key, value=value, query=query, cache=cache)
 
-        if torch.is_autocast_enabled():
-            query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32)
-
-        # temporary until we solve this more gracefully
-        with avoid_float16_autocast_context():
-            q, k, v = self.forward_qkv(query, key, value)
-            scores = torch.matmul(q, k.transpose(-2, -1)) / self.s_d_k
-            out = self.forward_attention(v, scores, mask)
+        q, k, v = self.forward_qkv(query, key, value)
+        scores = torch.matmul(q, k.transpose(-2, -1)) / self.s_d_k
+        out = self.forward_attention(v, scores, mask)
         if cache is None:
             return out
         else:
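The hunk above leaves the scaled dot-product attention itself unchanged; only the forced-fp32 wrapper is gone. A self-contained toy sketch of that computation (hypothetical shapes; `forward_qkv` and `forward_attention` are the module's own helpers and are not reproduced here):

```python
import torch

# Toy shapes standing in for the module's projections.
batch, heads, time, d_k = 2, 4, 10, 16
q = torch.randn(batch, heads, time, d_k)
k = torch.randn(batch, heads, time, d_k)
v = torch.randn(batch, heads, time, d_k)
mask = torch.zeros(batch, 1, time, time, dtype=torch.bool)  # True marks masked positions

scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)  # (batch, head, time, time)
scores = scores.masked_fill(mask, float("-inf"))
attn = torch.softmax(scores, dim=-1)
out = torch.matmul(attn, v)                                   # (batch, head, time, d_k)
out = out.transpose(1, 2).reshape(batch, time, heads * d_k)   # concatenate heads
print(out.shape)  # torch.Size([2, 10, 64])
```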
@@ -210,39 +205,34 @@ def forward(self, query, key, value, mask, pos_emb, cache=None):
         """
         key, value, query, cache = self.update_cache(key=key, value=value, query=query, cache=cache)
 
-        if torch.is_autocast_enabled():
-            query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32)
-
-        # temporary until we solve this more gracefully
-        with avoid_float16_autocast_context():
-            q, k, v = self.forward_qkv(query, key, value)
-            q = q.transpose(1, 2)  # (batch, time1, head, d_k)
-
-            n_batch_pos = pos_emb.size(0)
-            p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
-            p = p.transpose(1, 2)  # (batch, head, time1, d_k)
-
-            # (batch, head, time1, d_k)
-            q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
-            # (batch, head, time1, d_k)
-            q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
-
-            # compute attention score
-            # first compute matrix a and matrix c
-            # as described in https://arxiv.org/abs/1901.02860 Section 3.3
-            # (batch, head, time1, time2)
-            matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
-
-            # compute matrix b and matrix d
-            # (batch, head, time1, time2)
-            matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
-            matrix_bd = self.rel_shift(matrix_bd)
-            # drops extra elements in the matrix_bd to match the matrix_ac's size
-            matrix_bd = matrix_bd[:, :, :, : matrix_ac.size(-1)]
-
-            scores = (matrix_ac + matrix_bd) / self.s_d_k  # (batch, head, time1, time2)
-
-            out = self.forward_attention(v, scores, mask)
+        q, k, v = self.forward_qkv(query, key, value)
+        q = q.transpose(1, 2)  # (batch, time1, head, d_k)
+
+        n_batch_pos = pos_emb.size(0)
+        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
+        p = p.transpose(1, 2)  # (batch, head, time1, d_k)
+
+        # (batch, head, time1, d_k)
+        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+        # (batch, head, time1, d_k)
+        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+        # compute attention score
+        # first compute matrix a and matrix c
+        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+        # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+
+        # compute matrix b and matrix d
+        # (batch, head, time1, time2)
+        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+        matrix_bd = self.rel_shift(matrix_bd)
+        # drops extra elements in the matrix_bd to match the matrix_ac's size
+        matrix_bd = matrix_bd[:, :, :, : matrix_ac.size(-1)]
+
+        scores = (matrix_ac + matrix_bd) / self.s_d_k  # (batch, head, time1, time2)
+
+        out = self.forward_attention(v, scores, mask)
 
         if cache is None:
             return out
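The hunk above computes the Transformer-XL style score decomposition scores = (matrix_ac + matrix_bd) / sqrt(d_k). A toy sketch of that decomposition with random tensors follows; the rel-shift here is one common pad-and-reshape variant and may differ in detail from the module's `rel_shift`, and the zero biases merely stand in for the learned `pos_bias_u` / `pos_bias_v`:

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes: time1 == time2 for simplicity.
batch, heads, t1, d_k = 2, 4, 6, 8
q = torch.randn(batch, heads, t1, d_k)
k = torch.randn(batch, heads, t1, d_k)
p = torch.randn(batch, heads, 2 * t1 - 1, d_k)  # relative-position embeddings
pos_bias_u = torch.zeros(heads, d_k)            # stand-ins for the learned biases
pos_bias_v = torch.zeros(heads, d_k)

# content term and position term
matrix_ac = torch.matmul(q + pos_bias_u[None, :, None, :], k.transpose(-2, -1))
matrix_bd = torch.matmul(q + pos_bias_v[None, :, None, :], p.transpose(-2, -1))

# rel-shift: pad one zero column, reshape to slide each row, drop the first row
b, h, qlen, pos_len = matrix_bd.size()
matrix_bd = F.pad(matrix_bd, (1, 0))
matrix_bd = matrix_bd.view(b, h, pos_len + 1, qlen)[:, :, 1:].reshape(b, h, qlen, pos_len)
matrix_bd = matrix_bd[:, :, :, : matrix_ac.size(-1)]  # trim to matrix_ac's width

scores = (matrix_ac + matrix_bd) / (d_k ** 0.5)  # (batch, head, time1, time2)
print(scores.shape)  # torch.Size([2, 4, 6, 6])
```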
@@ -317,150 +307,145 @@ def forward(self, query, key, value, pad_mask, pos_emb, cache=None):
 
         key, value, query, cache = self.update_cache(key=key, value=value, query=query, cache=cache)
 
-        if torch.is_autocast_enabled():
-            query, key, value = query.to(torch.float32), key.to(torch.float32), value.to(torch.float32)
-
-        # temporary until we solve this more gracefully
-        with avoid_float16_autocast_context():
-            q, k, v = self.forward_qkv(query, key, value)
-            n_batch, _, T, _ = q.size()
-
-            w = max(self.att_context_size[0], self.att_context_size[1])
-            if w <= 0:
-                raise ValueError("When using local attention, context size must be set > 0")
-            pad_len = (2 * w - T % (2 * w)) % (2 * w)  # pad time to 2w
-            q = F.pad(q, (0, 0, 0, pad_len))  # (batch, head, time, size)
-            k = F.pad(k, (0, 0, 0, pad_len))  # (batch, head, time, size)
-            v = F.pad(v, (0, 0, 0, pad_len))  # (batch, head, time, size)
-            mask = F.pad(pad_mask, (0, pad_len), value=1.0)
-
-            q_with_bias_u = q + self.pos_bias_u.unsqueeze(1)  # (batch, head, time, size)
-            q_with_bias_v = q + self.pos_bias_v.unsqueeze(1)  # (batch, head, time, size)
-
-            diagonal_matrix_ac = self.sliding_chunks_matmul_qk(
-                q_with_bias_u, k, w, padding_value=0.0
-            )  # (batch, head, time, 2w + 1)
-
-            # add relative positional embedding
-
-            n_batch_pos = pos_emb.size(0)
-            p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k).transpose(1, 2)
-            # (batch, head, 2w, size)
-            diagonal_matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
-            # (batch, head, time, 2w + 1)
-
-            start_pos = w - self.att_context_size[0]
-            end_pos = w + self.att_context_size[1]
-
-            diagonal_matrix_ac[:, :, :, : self.att_context_size[0]] += diagonal_matrix_bd[
-                :, :, :, : self.att_context_size[0]
-            ]
-            diagonal_matrix_ac[:, :, :, -(self.att_context_size[1] + 1) :] += diagonal_matrix_bd[
-                :, :, :, self.att_context_size[0] :
-            ]
-            scores = diagonal_matrix_ac / self.s_d_k
-            # (batch, head, time, 2w + 1)
-
-            # mask invalid positions
-            scores[:, :, :, :start_pos] = -10000.0
-            scores[:, :, :, end_pos + 1 :] = -10000.0
-
-            # This implementation is fast and takes very little memory because num_heads x hidden_size = 1
-            # from (bsz x seq_len) to (bsz x num_heads x seqlen x hidden_size)
-            mask = mask.unsqueeze(dim=1).unsqueeze(dim=-1)
-            # cast to float/half then replace 1's with -inf
-            float_mask = mask.type_as(scores).masked_fill(mask, -10000.0)
-            ones = float_mask.new_ones(size=float_mask.size())  # tensor of ones
-            # diagonal mask with zeros everywhere and -inf inplace of padding
-            d_mask = self.sliding_chunks_matmul_qk(ones, float_mask, w, padding_value=0.0)
-            # (batch, head, time, 2w + 1)
-
-            scores += d_mask
-
-            if self.global_tokens > 0:
-
-                # create q, k, v for global attn
-                if self.global_attn_separate:
-                    global_q = self.global_q(query).view(n_batch, -1, self.h, self.d_k)
-                    global_k = self.global_k(key).view(n_batch, -1, self.h, self.d_k)
-                    global_v = self.global_v(value).view(n_batch, -1, self.h, self.d_k)
-                    global_q = global_q.transpose(1, 2)
-                    global_k = global_k.transpose(1, 2)
-                    global_v = global_v.transpose(1, 2)
-                    global_q = F.pad(global_q, (0, 0, 0, pad_len))  # (batch, head, time, size)
-                    global_k = F.pad(global_k, (0, 0, 0, pad_len))  # (batch, head, time, size)
-                    global_v = F.pad(global_v, (0, 0, 0, pad_len))  # (batch, head, time, size)
-                else:
-                    global_q, global_k, global_v = q, k, v
-
-                global_q /= self.s_d_k
-
-                # assign which tokens are global
-                is_index_global_attn = torch.zeros_like(pad_mask)
-                is_index_global_attn[
-                    :, : self.global_tokens * self.global_tokens_spacing : self.global_tokens_spacing
-                ] = 1.0
-
-                # compute global attn indices
-                (
-                    max_num_global_attn_indices,
-                    is_index_global_attn_nonzero,
-                    is_local_index_global_attn_nonzero,
-                    is_local_index_no_global_attn_nonzero,
-                ) = self._get_global_attn_indices(is_index_global_attn=is_index_global_attn)
-
-                # calculate global attn probs with global keys
-                # (batch, time, head, max_num_global_attn_indices)
-                global_key_attn = self._compute_global_key_attn(
-                    query=global_q.transpose(1, 2),
-                    key=global_k.transpose(1, 2),
-                    max_num_global_attn_indices=max_num_global_attn_indices,
-                    is_index_global_attn_nonzero=is_index_global_attn_nonzero,
-                    is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
-                    is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
-                ).transpose(1, 2)
-
-                # concat to local_attn_probs
-                # (batch, time, head, max_num_global_attn_indices + 2*w)
-                scores = torch.cat((global_key_attn, scores), dim=-1)
-
-                # free memory
-                del global_key_attn
-
-            attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
-            p_attn = self.dropout(attn)
-            # (batch, head, time, 2w + 1)
-
-            if self.global_tokens > 0:
-                # compute sum of global and local attn
-                out = self._compute_attn_output_with_global_indices(
-                    value=v,
-                    attn_probs=p_attn,
-                    max_num_global_attn_indices=max_num_global_attn_indices,
-                    is_index_global_attn_nonzero=is_index_global_attn_nonzero,
-                    is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
-                    w=w,
-                )
-            else:
-                # compute local attn only
-                out = self.sliding_chunks_matmul_pv(p_attn, v, w)
-
-            out = out.reshape(n_batch, -1, self.h * self.d_k)[:, :T]
-
-            if self.global_tokens > 0:
-                out_global_to_all = self._compute_out_global_to_all(
-                    query=global_q,
-                    key=global_k,
-                    value=global_v,
-                    max_num_global_attn_indices=max_num_global_attn_indices,
-                    is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
-                    is_index_global_attn_nonzero=is_index_global_attn_nonzero,
-                    is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
-                    is_index_masked=mask,
-                )
-
-                # overwrite values with global attention
-                out[is_index_global_attn_nonzero] = out_global_to_all
+        q, k, v = self.forward_qkv(query, key, value)
+        n_batch, _, T, _ = q.size()
+
+        w = max(self.att_context_size[0], self.att_context_size[1])
+        if w <= 0:
+            raise ValueError("When using local attention, context size must be set > 0")
+        pad_len = (2 * w - T % (2 * w)) % (2 * w)  # pad time to 2w
+        q = F.pad(q, (0, 0, 0, pad_len))  # (batch, head, time, size)
+        k = F.pad(k, (0, 0, 0, pad_len))  # (batch, head, time, size)
+        v = F.pad(v, (0, 0, 0, pad_len))  # (batch, head, time, size)
+        mask = F.pad(pad_mask, (0, pad_len), value=1.0)
+
+        q_with_bias_u = q + self.pos_bias_u.unsqueeze(1)  # (batch, head, time, size)
+        q_with_bias_v = q + self.pos_bias_v.unsqueeze(1)  # (batch, head, time, size)
+
+        diagonal_matrix_ac = self.sliding_chunks_matmul_qk(
+            q_with_bias_u, k, w, padding_value=0.0
+        )  # (batch, head, time, 2w + 1)
+
+        # add relative positional embedding
+
+        n_batch_pos = pos_emb.size(0)
+        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k).transpose(1, 2)
+        # (batch, head, 2w, size)
+        diagonal_matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+        # (batch, head, time, 2w + 1)
+
+        start_pos = w - self.att_context_size[0]
+        end_pos = w + self.att_context_size[1]
+
+        diagonal_matrix_ac[:, :, :, : self.att_context_size[0]] += diagonal_matrix_bd[
+            :, :, :, : self.att_context_size[0]
+        ]
+        diagonal_matrix_ac[:, :, :, -(self.att_context_size[1] + 1) :] += diagonal_matrix_bd[
+            :, :, :, self.att_context_size[0] :
+        ]
+        scores = diagonal_matrix_ac / self.s_d_k
+        # (batch, head, time, 2w + 1)
+
+        # mask invalid positions
+        scores[:, :, :, :start_pos] = -10000.0
+        scores[:, :, :, end_pos + 1 :] = -10000.0
+
+        # This implementation is fast and takes very little memory because num_heads x hidden_size = 1
+        # from (bsz x seq_len) to (bsz x num_heads x seqlen x hidden_size)
+        mask = mask.unsqueeze(dim=1).unsqueeze(dim=-1)
+        # cast to float/half then replace 1's with -inf
+        float_mask = mask.type_as(scores).masked_fill(mask, -10000.0)
+        ones = float_mask.new_ones(size=float_mask.size())  # tensor of ones
+        # diagonal mask with zeros everywhere and -inf inplace of padding
+        d_mask = self.sliding_chunks_matmul_qk(ones, float_mask, w, padding_value=0.0)
+        # (batch, head, time, 2w + 1)
+
+        scores += d_mask
+
+        if self.global_tokens > 0:
+
+            # create q, k, v for global attn
+            if self.global_attn_separate:
+                global_q = self.global_q(query).view(n_batch, -1, self.h, self.d_k)
+                global_k = self.global_k(key).view(n_batch, -1, self.h, self.d_k)
+                global_v = self.global_v(value).view(n_batch, -1, self.h, self.d_k)
+                global_q = global_q.transpose(1, 2)
+                global_k = global_k.transpose(1, 2)
+                global_v = global_v.transpose(1, 2)
+                global_q = F.pad(global_q, (0, 0, 0, pad_len))  # (batch, head, time, size)
+                global_k = F.pad(global_k, (0, 0, 0, pad_len))  # (batch, head, time, size)
+                global_v = F.pad(global_v, (0, 0, 0, pad_len))  # (batch, head, time, size)
+            else:
+                global_q, global_k, global_v = q, k, v
+
+            global_q /= self.s_d_k
+
+            # assign which tokens are global
+            is_index_global_attn = torch.zeros_like(pad_mask)
+            is_index_global_attn[
+                :, : self.global_tokens * self.global_tokens_spacing : self.global_tokens_spacing
+            ] = 1.0
+
+            # compute global attn indices
+            (
+                max_num_global_attn_indices,
+                is_index_global_attn_nonzero,
+                is_local_index_global_attn_nonzero,
+                is_local_index_no_global_attn_nonzero,
+            ) = self._get_global_attn_indices(is_index_global_attn=is_index_global_attn)
+
+            # calculate global attn probs with global keys
+            # (batch, time, head, max_num_global_attn_indices)
+            global_key_attn = self._compute_global_key_attn(
+                query=global_q.transpose(1, 2),
+                key=global_k.transpose(1, 2),
+                max_num_global_attn_indices=max_num_global_attn_indices,
+                is_index_global_attn_nonzero=is_index_global_attn_nonzero,
+                is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
+                is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
+            ).transpose(1, 2)
+
+            # concat to local_attn_probs
+            # (batch, time, head, max_num_global_attn_indices + 2*w)
+            scores = torch.cat((global_key_attn, scores), dim=-1)
+
+            # free memory
+            del global_key_attn
+
+        attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
+        p_attn = self.dropout(attn)
+        # (batch, head, time, 2w + 1)
+
+        if self.global_tokens > 0:
+            # compute sum of global and local attn
+            out = self._compute_attn_output_with_global_indices(
+                value=v,
+                attn_probs=p_attn,
+                max_num_global_attn_indices=max_num_global_attn_indices,
+                is_index_global_attn_nonzero=is_index_global_attn_nonzero,
+                is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
+                w=w,
+            )
+        else:
+            # compute local attn only
+            out = self.sliding_chunks_matmul_pv(p_attn, v, w)
+
+        out = out.reshape(n_batch, -1, self.h * self.d_k)[:, :T]
+
+        if self.global_tokens > 0:
+            out_global_to_all = self._compute_out_global_to_all(
+                query=global_q,
+                key=global_k,
+                value=global_v,
+                max_num_global_attn_indices=max_num_global_attn_indices,
+                is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
+                is_index_global_attn_nonzero=is_index_global_attn_nonzero,
+                is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
+                is_index_masked=mask,
+            )
+
+            # overwrite values with global attention
+            out[is_index_global_attn_nonzero] = out_global_to_all
 
         ret = self.linear_out(out)
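The local-attention hunk above pads the time axis to a multiple of 2*w and masks relative positions outside the configured context window. A small sketch of just that bookkeeping with hypothetical sizes; `sliding_chunks_matmul_qk` and the global-token path are not reproduced:

```python
import torch
import torch.nn.functional as F

left_context, right_context = 3, 2            # stand-ins for att_context_size
w = max(left_context, right_context)          # half-window used for chunking
T = 13
pad_len = (2 * w - T % (2 * w)) % (2 * w)      # pad time up to a multiple of 2w
print(pad_len, T + pad_len)                    # 5 18

q = torch.randn(1, 4, T, 8)                    # (batch, head, time, size)
q = F.pad(q, (0, 0, 0, pad_len))               # -> (1, 4, 18, 8)

# toy local scores over the 2w + 1 relative positions of each query
scores = torch.randn(1, 4, T + pad_len, 2 * w + 1)
start_pos = w - left_context                   # 0
end_pos = w + right_context                    # 5
scores[:, :, :, :start_pos] = -10000.0         # outside the left context
scores[:, :, :, end_pos + 1:] = -10000.0       # outside the right context
attn = torch.softmax(scores, dim=-1)
print(attn.shape)  # torch.Size([1, 4, 18, 7])
```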
