Commit

Update hf_decoder_model.py
add a6000 support for flash attention
yaoguany authored Nov 6, 2023
1 parent f846912 commit e888923
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/lmflow/models/hf_decoder_model.py
@@ -71,15 +71,17 @@

 GPU_SUPPORT_FLASH_ATTENTION = {
     "A100": ["LlamaForCausalLM", "GPTNeoForCausalLM", "GPT2ForCausalLM", "BloomForCausalLM"],
-    "A40": ["GPTNeoForCausalLM", "GPT2ForCausalLM", "BloomForCausalLM"]
+    "A40": ["GPTNeoForCausalLM", "GPT2ForCausalLM", "BloomForCausalLM"],
+    "A6000": ["LlamaForCausalLM", "GPTNeoForCausalLM", "GPT2ForCausalLM", "BloomForCausalLM"]
 }

 try:
     import flash_attn
     if int(flash_attn.__version__.split(".")[0]) == 2:
         GPU_SUPPORT_FLASH_ATTENTION = {
             "A100": ["LlamaForCausalLM", "GPTNeoForCausalLM", "GPT2ForCausalLM", "BloomForCausalLM"],
-            "A40": ["LlamaForCausalLM","GPTNeoForCausalLM", "GPT2ForCausalLM", "BloomForCausalLM"]
+            "A40": ["LlamaForCausalLM","GPTNeoForCausalLM", "GPT2ForCausalLM", "BloomForCausalLM"],
+            "A6000": ["LlamaForCausalLM", "GPTNeoForCausalLM", "GPT2ForCausalLM", "BloomForCausalLM"]
         }
 except:
     pass
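
For context, GPU_SUPPORT_FLASH_ATTENTION maps a GPU model name to the architectures for which flash attention may be enabled, and with flash-attn 2.x installed LLaMA is additionally allowed on the A40. Below is a minimal sketch of how such a map could be consulted at load time; it is not taken from hf_decoder_model.py, and the flash_attention_supported helper plus the substring match on torch.cuda.get_device_name() are assumptions for illustration.

# Minimal sketch, not part of this commit: one way a support map like this
# could gate flash attention at model-load time. The helper name and the
# substring match against the reported device name are assumptions.
import torch

from lmflow.models.hf_decoder_model import GPU_SUPPORT_FLASH_ATTENTION


def flash_attention_supported(model_class_name: str) -> bool:
    """Check whether the detected GPU/model pair appears in the support map."""
    if not torch.cuda.is_available():
        return False
    device_name = torch.cuda.get_device_name(0)  # e.g. "NVIDIA A100-SXM4-80GB"
    return any(
        gpu_key in device_name and model_class_name in model_classes
        for gpu_key, model_classes in GPU_SUPPORT_FLASH_ATTENTION.items()
    )


# Hypothetical usage: only request flash attention when the pair is supported.
# use_flash_attention = flash_attention_supported(type(model).__name__)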
