
Flash Attention for Neuron #939

Open · wants to merge 1 commit into base: main
Conversation

@apoorvtintin (Contributor) commented Jan 21, 2025

This PR adds support for flash attention kernel for Neuron implemented through Neuron Kernel Interface (NKI).

The flash attention kernel works with TRN1 and TRN2.

This PR is a newer version of #883 from a different fork; all comments from the previous PR have been addressed here. Unlike the previous PR, this version includes dropout support.

Segment ID support in the flash attention kernel is in progress and will be available at a later date.
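As a reference for what any flash attention kernel must compute, here is a minimal NumPy sketch of the scaled dot-product attention forward pass. This is illustrative only (names and shapes are assumptions, not this PR's actual NKI kernel API); a fused kernel computes the same result without materializing the full probability matrix:

```python
import numpy as np

def reference_attention(q, k, v, softmax_scale=None, causal=False):
    """Plain softmax attention: the semantics a flash attention kernel must match.

    q: [..., seq_q, head_dim], k/v: [..., seq_k, head_dim].
    """
    seq_q, head_dim = q.shape[-2], q.shape[-1]
    seq_k = k.shape[-2]
    if softmax_scale is None:
        softmax_scale = 1.0 / np.sqrt(head_dim)
    # Attention logits: [..., seq_q, seq_k].
    logits = (q @ np.swapaxes(k, -1, -2)) * softmax_scale
    if causal:
        # Query i may only attend to keys j <= i.
        mask = np.tril(np.ones((seq_q, seq_k), dtype=bool))
        logits = np.where(mask, logits, -np.inf)
    # Subtract the row max for numerical stability before exponentiating.
    logits = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
    return probs @ v
```

A fused implementation is validated by comparing its output against a reference like this within a numerical tolerance.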

@kelvin-zou (Contributor) left a comment

Maybe wait until this PR is checked in. From what I can tell, your PR also has the unfixed remat bug. #942 (review)

axlearn/common/flash_attention/utils.py


def _mha_forward(query, key, value, bias, causal, softmax_scale, dropout_rate):
# Get the batch size, sequence lengths, number of heads, and hidden dimension
Nit: end comments with . (here and everywhere)

@apoorvtintin force-pushed the mainline_upstream_fa branch 3 times, most recently from f40c4cc to 8a92182, on January 23, 2025 02:16
key: Tensor,
value: Tensor,
bias: Tensor,
causal: bool = False,

Can we support segment IDs? Or, even better, a more general masking fn (with optimized handling).


If not, I am fine with leaving a TODO here, but it is a hard blocker for enabling it for our internal training.


Can we do segment IDs in a separate PR? That involves non-trivial work and needs some time.
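For context on what segment ID support entails: when multiple sequences are packed into one batch row, attention must be restricted to positions within the same segment. A minimal sketch of that mask (illustrative names, not this PR's API; segment boundaries are assumed to be encoded as integer IDs per position):

```python
import numpy as np

def segment_id_mask(q_segment_ids, kv_segment_ids):
    """Boolean mask of shape [..., seq_q, seq_k].

    Entry (i, j) is True only when query position i and key position j
    carry the same segment ID, i.e. belong to the same packed sequence.
    """
    return q_segment_ids[..., :, None] == kv_segment_ids[..., None, :]
```

Inside a fused kernel, masked-out logits would be set to -inf before the softmax, analogous to causal masking.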

@apoorvtintin (Contributor, Author)
Thanks for all the reviews @ruomingp @kelvin-zou. I have resolved all the comments; please let me know if any more changes are needed.

4 participants