Flash Attention for Neuron #939
base: main
Conversation
Maybe wait until this PR is checked in. From what I can tell, your PR still has the unfixed remat bug. #942 (review)
def _mha_forward(query, key, value, bias, causal, softmax_scale, dropout_rate):
    # Get the batch size, sequence lengths, number of heads, and hidden dimension
Nit: end comments with . (here and everywhere)
    key: Tensor,
    value: Tensor,
    bias: Tensor,
    causal: bool = False,
Can we support segment IDs? A more general masking fn (with optimized handling) would be even better.
If not, I am fine with leaving a TODO here, but it is a hard blocker for enabling it for our internal training.
Can we do segment IDs in a separate PR? That involves non-trivial work and needs some time.
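For context, segment IDs typically induce a block-diagonal attention mask (each position may attend only within its own segment). A minimal NumPy sketch of that mask, with `segment_id_mask` as a hypothetical helper name not taken from this PR:

```python
import numpy as np

def segment_id_mask(q_segment_ids, kv_segment_ids):
    # Hypothetical helper (not in this PR): q_segment_ids has shape
    # [batch, q_len], kv_segment_ids has shape [batch, kv_len]. A query
    # position may attend to a key position only when their segment IDs
    # match, which yields a block-diagonal boolean mask.
    return q_segment_ids[:, :, None] == kv_segment_ids[:, None, :]

# Two segments of length 2 each in a sequence of length 4.
q_ids = np.array([[1, 1, 2, 2]])
kv_ids = np.array([[1, 1, 2, 2]])
mask = segment_id_mask(q_ids, kv_ids)  # shape (1, 4, 4)
```

A fused kernel would not materialize this mask; it would skip or mask blocks inside the kernel loop, which is part of why optimized handling is non-trivial.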
Thanks for all the reviews @ruomingp @kelvin-zou. I have resolved all the comments; please let me know if any further changes are needed.
This PR adds support for a flash attention kernel for Neuron, implemented through the Neuron Kernel Interface (NKI).
The flash attention kernel works with TRN1 and TRN2.
This PR is a newer version of #883 from a different fork; all comments from the previous PR are addressed here, and it adds dropout support. Segment ID support in the flash attention kernel is in progress and will be available at a later date.
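As background, the dense attention semantics that a fused flash attention kernel is expected to reproduce can be sketched in NumPy. This is illustrative only: the argument names mirror the PR's `_mha_forward` signature, but the body, the `-1e9` masking constant, and the layout convention are assumptions, not the kernel's implementation.

```python
import numpy as np

def reference_mha(query, key, value, bias=None, causal=False, softmax_scale=1.0):
    # Illustrative reference semantics (not the NKI kernel).
    # query/key/value: [batch, seq, heads, head_dim].
    logits = np.einsum("bqhd,bkhd->bhqk", query, key) * softmax_scale
    if bias is not None:
        logits = logits + bias
    if causal:
        q_len, k_len = logits.shape[-2:]
        # Lower-triangular mask: query i attends to keys j <= i.
        mask = np.tril(np.ones((q_len, k_len), dtype=bool))
        logits = np.where(mask, logits, -1e9)  # assumed masking constant
    # Numerically stable softmax over the key axis.
    logits = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(logits)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    return np.einsum("bhqk,bkhd->bqhd", probs, value)
```

A flash attention kernel computes the same result tile by tile with an online softmax, never materializing the full `[q_len, k_len]` logits matrix; dropout (when enabled) is applied to `probs` before the final contraction.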