Skip to content

Commit

Permalink
Merge pull request #670 from yaoguany/flash_attn2
Browse files Browse the repository at this point in the history
Update hf_decoder_model.py
  • Loading branch information
research4pan authored Nov 6, 2023
2 parents f846912 + e888923 commit f49123d
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/lmflow/models/hf_decoder_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,17 @@

# Map of GPU model -> HuggingFace architectures known to work with
# FlashAttention (v1) on that GPU. Consulted before enabling flash
# attention for a loaded model.
# NOTE: the diff view had duplicated the "A40" line (pre- and post-change);
# this is the reconstructed post-commit dict, which adds "A6000" support.
GPU_SUPPORT_FLASH_ATTENTION = {
    "A100": ["LlamaForCausalLM", "GPTNeoForCausalLM", "GPT2ForCausalLM", "BloomForCausalLM"],
    "A40": ["GPTNeoForCausalLM", "GPT2ForCausalLM", "BloomForCausalLM"],
    "A6000": ["LlamaForCausalLM", "GPTNeoForCausalLM", "GPT2ForCausalLM", "BloomForCausalLM"],
}

# If FlashAttention-2 is installed, widen the support table: with v2 the
# Llama architecture is also supported on A40 (and A6000). Best-effort
# probe — if flash_attn is absent or its version string is unparseable,
# keep the v1 table defined above.
# NOTE: the diff view had duplicated the "A40" line (pre- and post-change);
# this is the reconstructed post-commit code, which adds "A6000" support.
try:
    import flash_attn
    if int(flash_attn.__version__.split(".")[0]) == 2:
        GPU_SUPPORT_FLASH_ATTENTION = {
            "A100": ["LlamaForCausalLM", "GPTNeoForCausalLM", "GPT2ForCausalLM", "BloomForCausalLM"],
            "A40": ["LlamaForCausalLM", "GPTNeoForCausalLM", "GPT2ForCausalLM", "BloomForCausalLM"],
            "A6000": ["LlamaForCausalLM", "GPTNeoForCausalLM", "GPT2ForCausalLM", "BloomForCausalLM"],
        }
except Exception:  # narrowed from bare `except:` — still best-effort, but no longer traps SystemExit/KeyboardInterrupt
    pass
Expand Down

0 comments on commit f49123d

Please sign in to comment.