Update llama_flash_attention.py
yaoguany authored Nov 26, 2023
1 parent ae0cd54 commit 9280a0b
Showing 1 changed file with 30 additions and 8 deletions.
38 changes: 30 additions & 8 deletions src/lmflow/utils/flash_attention/llama_flash_attention.py
@@ -9,11 +9,7 @@
 
 from einops import rearrange
 
-#try to import flash_attn 2.x.x, if not, import flash_attn 1.x.x
-try:
-    from flash_attn.flash_attn_interface import flash_attn_func
-except:
-    from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func as flash_attn_func
+from flash_attn.flash_attn_interface import flash_attn_func,flash_attn_varlen_func
 
 from flash_attn.bert_padding import unpad_input, pad_input
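For context, the two kernels imported above expect different input layouts: flash_attn_func consumes padded (batch, seqlen, nheads, headdim) tensors, while flash_attn_varlen_func consumes packed (total_tokens, nheads, headdim) tensors plus cumulative sequence lengths. A minimal sketch of the two call shapes, assuming flash-attn 2.x and a CUDA device; the tensor sizes are illustrative, not taken from the repository:

import torch
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func

# Illustrative sizes; the flash-attn kernels require a CUDA device and fp16/bf16 inputs.
batch, seqlen, nheads, headdim = 2, 8, 4, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Padded path: one call over the full (batch, seqlen, ...) block.
out = flash_attn_func(q, k, v, causal=True)  # -> (batch, seqlen, nheads, headdim)

# Variable-length path: sequences are packed along dim 0 and delimited by cu_seqlens.
q_packed = q.reshape(batch * seqlen, nheads, headdim)
k_packed, v_packed = k.reshape_as(q_packed), v.reshape_as(q_packed)
cu_seqlens = torch.arange(0, (batch + 1) * seqlen, seqlen, device="cuda", dtype=torch.int32)
out_packed = flash_attn_varlen_func(
    q_packed, k_packed, v_packed,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=seqlen, max_seqlen_k=seqlen,
    causal=True,
)  # -> (batch * seqlen, nheads, headdim)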

@@ -70,8 +66,34 @@ def forward(
         key_states = key_states.to(target_dtype)
         value_states = value_states.to(target_dtype)
 
-    # below output will have shape (batch_size, seqlen, nheads, headdim)
-    attn_output = flash_attn_func(query_states, key_states, value_states, causal=True)
+    dropout = 0.0 if not self.training else self.attention_dropout
+
+    # Contains at least one padding token in the sequence
+    if attention_mask is not None:
+        batch_size = query_states.shape[0]
+        query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+            query_states, key_states, value_states, attention_mask, q_len
+        )
+
+        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+        attn_output_unpad = flash_attn_varlen_func(
+            query_states,
+            key_states,
+            value_states,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_in_batch_q,
+            max_seqlen_k=max_seqlen_in_batch_k,
+            dropout_p=dropout,
+            causal=self.is_causal,
+        )
+        attn_output = pad_input(attn_output_unpad, indices_q, batch_size, q_len)
+
+    else:
+        # below output will have shape (batch_size, seqlen, nheads, headdim)
+        attn_output = flash_attn_func(query_states, key_states, value_states, causal=True)
 
     if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
         raise ValueError(
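The cu_seqlens_* and max_seqlen_* values consumed above are derived from the padding mask: _upad_input drops padded positions (via unpad_input) and records where each sequence starts inside the packed tensor, and pad_input scatters the attention output back afterwards. A rough sketch of that bookkeeping in plain PyTorch, runnable on CPU; the variable names mirror the diff, but the helper logic here is illustrative, not the repository's:

import torch
import torch.nn.functional as F

# Two sequences of true lengths 3 and 2, right-padded to length 4.
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

# Per-sequence lengths and their cumulative offsets into the packed tensor.
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
max_seqlen_in_batch = int(seqlens_in_batch.max())

print(cu_seqlens)           # tensor([0, 3, 5], dtype=torch.int32)
print(max_seqlen_in_batch)  # 3

# Indices of the non-padded tokens in the flattened (batch * seqlen) view;
# gathering with them packs q/k/v, and pad_input later scatters the output back.
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
print(indices)              # tensor([0, 1, 2, 4, 5])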
@@ -121,4 +143,4 @@ def _prepare_decoder_attention_mask(
 
 def replace_llama_attn_with_flash_attn():
     transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
-    transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
+    transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
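A hypothetical usage sketch for the patched attention: replace_llama_attn_with_flash_attn swaps methods on the transformers Llama classes, so it is called once before the model is used. The checkpoint name and dtype below are illustrative, and a CUDA device is assumed:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from lmflow.utils.flash_attention.llama_flash_attention import replace_llama_attn_with_flash_attn

# Patch LlamaAttention.forward and LlamaModel._prepare_decoder_attention_mask
# on the classes themselves, then load the model as usual.
replace_llama_attn_with_flash_attn()

model_name = "meta-llama/Llama-2-7b-hf"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).cuda()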
