add more info when fail to import flash attn
yizhen committed Mar 26, 2024
1 parent 4d908c0 commit 1195482
Showing 2 changed files with 23 additions and 2 deletions.
16 changes: 16 additions & 0 deletions readme/flash_attn2.md
@@ -16,3 +16,19 @@ deepspeed --master_port=11000 \
```

Upgrade to LMFlow now and experience the future of language modeling!


## Known Issues
### 1. `undefined symbol` error
When importing the flash attention module, you may encounter an `ImportError` reporting an `undefined symbol`:
```python
>>> import flash_attn
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File ".../anaconda3/envs/lmflow/lib/python3.9/site-packages/flash_attn/__init__.py", line 3, in <module>
from flash_attn.flash_attn_interface import flash_attn_func
File ".../anaconda3/envs/lmflow/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 4, in <module>
import flash_attn_2_cuda as flash_attn_cuda
ImportError: .../anaconda3/envs/lmflow/lib/python3.9/site-packages/flash_attn_2_cuda.cpython-39-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops9_pad_enum4callERKNS_6TensorEN3c108ArrayRefINS5_6SymIntEEElNS5_8optionalIdEE
```
This may be caused by an incompatibility between the installed PyTorch version and the flash attention module, or by a problem during the compilation of flash attention. In our tests, either downgrading PyTorch or upgrading the flash attention module resolves the error. If you still encounter this issue, please refer to [this issue](https://github.com/Dao-AILab/flash-attention/issues/451).
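
As a quick way to see which side is mismatched before changing anything, a small diagnostic along the lines below can help. This snippet is only an illustration (it is not part of LMFlow or of this commit) and assumes nothing beyond `torch`, and optionally `flash_attn`, being installed:

```python
# Illustrative check: report the installed PyTorch build and retry the import
# that fails in the traceback above.
import torch

print("PyTorch:", torch.__version__, "built for CUDA:", torch.version.cuda)

try:
    import flash_attn  # importing the package pulls in the compiled flash_attn_2_cuda extension
    print("flash-attn:", flash_attn.__version__, "- import OK")
except ImportError as e:
    # An `undefined symbol` message here indicates an ABI mismatch between the
    # prebuilt flash-attn wheel and the installed PyTorch.
    print("flash-attn import failed:", e)
```

If the import fails, aligning the two versions as described above (downgrading PyTorch or upgrading/rebuilding flash attention) should clear the error.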
9 changes: 7 additions & 2 deletions src/lmflow/models/hf_decoder_model.py
@@ -83,8 +83,13 @@
"A40": ["LlamaForCausalLM","GPTNeoForCausalLM", "GPT2ForCausalLM", "BloomForCausalLM"],
"A6000": ["LlamaForCausalLM", "GPTNeoForCausalLM", "GPT2ForCausalLM", "BloomForCausalLM"]
}
-except:
-    pass
+except Exception as e:
+    if e.__class__ == ModuleNotFoundError:
+        logger.warning(
+            "flash_attn is not installed. Install flash_attn for better performance."
+        )
+    else:
+        raise e

class HFDecoderModel(DecoderModel, Tunable):
r"""
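For reference, the behavior of the new handler can also be expressed with dedicated `except` clauses rather than an `e.__class__` check. The sketch below is not LMFlow source; it uses a plain `logging` logger and a bare `import flash_attn` as stand-ins for the module's actual logger and flash-attention imports:

```python
# Sketch of the error handling introduced above: warn when flash_attn is simply
# not installed, but let every other failure surface instead of silently passing.
import logging

logger = logging.getLogger(__name__)

try:
    import flash_attn  # stand-in for the flash-attention imports in hf_decoder_model.py
except ModuleNotFoundError:
    # The package is missing: degrade gracefully, flash attention is optional.
    logger.warning("flash_attn is not installed. Install flash_attn for better performance.")
except Exception:
    # e.g. the `undefined symbol` ImportError from readme/flash_attn2.md is re-raised,
    # so users see the real cause rather than a silent `pass`.
    raise
```

Catching `ModuleNotFoundError` first (a subclass of `ImportError`) is what lets the handler distinguish "package not installed" from "package installed but broken".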
