Skip to content

Commit

Permalink
fixed masked flash attention (#2589)
Browse files Browse the repository at this point in the history
* fixed masked flash attention
  • Loading branch information
l-k-11235 authored Jun 27, 2024
1 parent 63f07fc commit 9991c8d
Showing 1 changed file with 10 additions and 20 deletions.
30 changes: 10 additions & 20 deletions onmt/modules/multi_headed_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,6 @@ def forward(
"""
# 1) Project key, value, and query.
# as a reminder at training layer_cache[0] remains False
key_pad_mask = self.layer_cache[1].get("key_pad_mask", None)
if self.layer_cache[0]:
# Retrieve keys and values from the KV cache (decoding mode only).
if self.attn_type == "self":
Expand Down Expand Up @@ -484,6 +483,16 @@ def forward(
key = key[:, :, 1:, :]
value = value[:, :, 1:, :]

if step == 0:
key_pad_mask = self.layer_cache[1].get("key_pad_mask", None)
if key_pad_mask is not None:
x = key_pad_mask.expand(
-1, self.head_count // self.parallel_gpu, -1
)
x = x.unsqueeze(3)
x = x.expand(-1, -1, -1, value.size(3))
value = value.masked_fill(x, 0)

self.layer_cache[1]["keys"] = key
self.layer_cache[1]["values"] = value

Expand Down Expand Up @@ -565,19 +574,6 @@ def forward(
self.layer_cache[1]["keys"] = key
self.layer_cache[1]["values"] = value

if key_pad_mask is not None:
# Increase the cached key pad mask by concatenation.
# For decoding only.
if step > 0:
y = torch.zeros(
(key_pad_mask.size(0), key_pad_mask.size(1), 1),
dtype=torch.bool,
device=key_pad_mask.device,
)
self.layer_cache[1]["key_pad_mask"] = torch.cat(
(key_pad_mask, y), 2
)
key_pad_mask = self.layer_cache[1]["key_pad_mask"]
else:
# Retrieve keys and values from linear layers (training mode).
key = self.maybe_ckpt(self.linear_keys, key)
Expand Down Expand Up @@ -706,8 +702,6 @@ def forward(
scores = self.alibi(scores)

scores = scores.float()
if key_pad_mask is not None and mask is None:
mask = key_pad_mask.unsqueeze(1)

if mask is not None:
# not 100% necessary but expand to nb of heads
Expand All @@ -727,10 +721,6 @@ def forward(
attn_output.add_(relative_matmul(drop_attn, relations_values, False))

context = unshape(attn_output)
if key_pad_mask is not None:
if key_pad_mask.size(0) > 1 and context.size(1) > 1:
x = key_pad_mask.squeeze(1).unsqueeze(2).expand(-1, -1, context.size(2))
context = context.masked_fill(x, 0)

if self.layer_cache[0]:
attn_output = self.final_linear(context)
Expand Down

0 comments on commit 9991c8d

Please sign in to comment.