Commit f40c4cc (1 parent: 185b1b5)
Showing 3 changed files with 333 additions and 2 deletions.
168 changes: 168 additions & 0 deletions
axlearn/common/flash_attention/neuron_attention.py
@@ -0,0 +1,168 @@
# Copyright © 2024 Amazon Inc.
"""Flash attention kernels using NKI on Neuron. Tested on trn1 & trn2."""
from functools import partial

import jax
import jax.numpy as jnp

# Import needed to enable the JAX cache on Neuron.
import jax_neuronx  # pylint: disable=unused-import
import neuronxcc.nki.language as nl
from jax import custom_vjp
from neuronxcc.nki.kernels.attention import flash_attn_bwd, flash_fwd

Tensor = jax.Array

# lnc: how many NeuronCores each kernel launch is sharded across (2 on NC_v3d, i.e. trn2; otherwise 1).
lnc = 2 if jax.devices()[0].device_kind == "NC_v3d" else 1

@partial(custom_vjp, nondiff_argnums=(4, 5, 6))
def flash_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    bias: Tensor,
    causal: bool = False,
    softmax_scale: float = 1.0,
    dropout_rate: float = 0.0,
):
    """Wraps _mha_forward for custom vjp.

    Args:
        query: Query of shape [batch_size, target_length, num_heads, per_head_dim].
        key: Key of shape [batch_size, source_length, num_heads, per_head_dim].
        value: Value of shape [batch_size, source_length, num_heads, per_head_dim].
        bias: Optional logit biases of shape [batch_size, num_heads, target_length, source_length].
        causal: Whether to apply a causal mask.
        softmax_scale: Optional scale to apply to the softmax. Defaults to 1.
        dropout_rate: Dropout rate. Defaults to 0.0 (no dropout).

    Returns:
        The attention outputs of shape [batch_size, target_length, num_heads, per_head_dim].
    """
    out, _ = _mha_forward(query, key, value, bias, causal, softmax_scale, dropout_rate)
    return out

def _mha_forward(query, key, value, bias, causal, softmax_scale, dropout_rate):
    """Computes attention outputs following FlashAttention.

    See also `_mha_backward` for the backward pass.

    Args:
        query: Input query.
        key: Input key.
        value: Input value.
        bias: Input bias.
        causal: Whether to apply a causal mask.
        softmax_scale: Softmax scale to use in the kernel.
        dropout_rate: Dropout rate to use in the kernel.

    Returns:
        The attention output and the residuals needed by `_mha_backward`.
    """
    # Get the batch size, sequence lengths, number of heads, and hidden dimension.
    batch_size, _, num_heads, _ = query.shape

    # Transpose the query, key, and value tensors into the layouts expected by the kernel.
    q = query.transpose(0, 2, 3, 1)  # [batch_size, num_heads, d_model, q_seq_len].
    k = key.transpose(0, 2, 3, 1)  # [batch_size, num_heads, d_model, kv_seq_len].
    v = value.transpose(0, 2, 1, 3)  # [batch_size, num_heads, kv_seq_len, d_model].

    seed = jnp.array([1])

    # Call the NKI kernel; duplicate the kernel if we cannot shard on num_heads.
    if (num_heads % 2) == 0 and (num_heads // 2 > 0):
        grid = batch_size, nl.nc(lnc) * (num_heads // lnc)
    else:
        grid = batch_size, num_heads

    if bias is not None:
        assert (
            bias.ndim == 4
        ), f"Neuron flash_attention is only expecting bias.ndim = 4 but got {bias.ndim}"
        attn_output, lse = flash_fwd[grid](
            q,
            k,
            v,
            seed,
            bias,
            use_causal_mask=causal,
            softmax_scale=softmax_scale,
            mixed_precision=True,
            dropout_p=dropout_rate,
        )
    else:
        attn_output, lse = flash_fwd[grid](
            q,
            k,
            v,
            seed,
            use_causal_mask=causal,
            softmax_scale=softmax_scale,
            mixed_precision=True,
            dropout_p=dropout_rate,
        )
    # Transpose the output back to the original shape.
    attn_output = attn_output.transpose(0, 2, 1, 3)  # [batch_size, q_seq_len, num_heads, d_model].

    return attn_output, (lse, attn_output, q, k, v, bias)

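# Note (illustrative, not part of the diff): with num_heads=8 and lnc=2 on trn2, the grid
# above becomes (batch_size, nl.nc(2) * 4), i.e. the head-indexed kernel instances are
# sharded across 2 NeuronCores with 4 heads each; when the heads cannot be split evenly
# (e.g. num_heads=1), the fallback grid (batch_size, num_heads) is used instead.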
def _mha_backward(causal, softmax_scale, dropout_rate, res, d_attn_output):
    """Computes gradients w.r.t. query, key, and value from the forward residuals."""
    lse, o, q, k, v, bias = res
    batch_size, num_heads, _, _ = q.shape

    # Transpose the output and its gradient into the layouts expected by the kernel.
    o = o.transpose(0, 2, 3, 1)
    dy = d_attn_output.transpose(0, 2, 3, 1)

    # Transpose the v tensor.
    v = jnp.transpose(v, axes=(0, 1, 3, 2))
    seed = jnp.array([1])

    # Call the NKI kernel; duplicate the kernel if we cannot shard on num_heads.
    if (num_heads % 2) == 0 and (num_heads // 2 > 0):
        grid = batch_size, nl.nc(lnc) * (num_heads // lnc)
    else:
        grid = batch_size, num_heads

    if bias is not None:
        assert (
            bias.ndim == 4
        ), f"Neuron flash_attention is only expecting bias.ndim = 4 but got {bias.ndim}"
        d_query, d_key, d_value = flash_attn_bwd[grid](
            q,
            k,
            v,
            o,
            dy,
            lse,
            seed,
            bias,
            use_causal_mask=causal,
            mixed_precision=True,
            dropout_p=dropout_rate,
            softmax_scale=softmax_scale,
        )
    else:
        d_query, d_key, d_value = flash_attn_bwd[grid](
            q,
            k,
            v,
            o,
            dy,
            lse,
            seed,
            use_causal_mask=causal,
            mixed_precision=True,
            dropout_p=dropout_rate,
            softmax_scale=softmax_scale,
        )

    # Transpose the gradients back to the original [batch_size, seq_len, num_heads, d_model] shape.
    d_query = d_query.transpose(0, 3, 1, 2)
    d_key = d_key.transpose(0, 3, 1, 2)
    d_value = d_value.transpose(0, 3, 1, 2)

    # No gradient is returned for the bias.
    return d_query, d_key, d_value, None


flash_attention.defvjp(_mha_forward, _mha_backward)
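For orientation, a minimal usage sketch of the wrapper defined above (illustrative, not part of the commit): shapes follow the flash_attention docstring, a Neuron backend with the module on the import path is assumed, and jax.grad of a loss built on flash_attention dispatches to _mha_backward through the defvjp registration.

import jax
import jax.numpy as jnp

from axlearn.common.flash_attention.neuron_attention import flash_attention

# Illustrative shapes: [batch_size, seq_len, num_heads, per_head_dim].
batch_size, seq_len, num_heads, per_head_dim = 2, 2048, 8, 128
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=jnp.bfloat16)
k = jax.random.normal(k2, (batch_size, seq_len, num_heads, per_head_dim), dtype=jnp.bfloat16)
v = jax.random.normal(k3, (batch_size, seq_len, num_heads, per_head_dim), dtype=jnp.bfloat16)

# bias may be None or have shape [batch_size, num_heads, target_length, source_length].
out = flash_attention(q, k, v, None, causal=True, softmax_scale=per_head_dim**-0.5)
assert out.shape == (batch_size, seq_len, num_heads, per_head_dim)

# Gradients flow through the registered custom vjp (_mha_backward).
dq = jax.grad(lambda q: flash_attention(q, k, v, None, causal=True).sum())(q)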
141 changes: 141 additions & 0 deletions
axlearn/common/flash_attention/neuron_attention_test.py
@@ -0,0 +1,141 @@
# Copyright © 2024 Amazon Inc.
"""Tests for Flash attention on Neuron. Tested on trn1 & trn2."""

import chex
import jax
import jax.numpy as jnp
import pytest

from axlearn.common.flash_attention.neuron_attention import flash_attention
from axlearn.common.flash_attention.utils import mha_reference

if jax.default_backend() != "neuron":
    pytestmark = pytest.mark.skip(reason="Incompatible hardware, AWS Neuron only test.")

@pytest.mark.parametrize(
    "batch_size,seq_len,num_heads,per_head_dim",
    [
        (1, 2048, 1, 64),
        (2, 2048, 2, 64),
        (1, 2048, 1, 128),
        (2, 2048, 2, 128),
        (1, 2048, 8, 128),
        (2, 2048, 8, 128),
    ],
)
@pytest.mark.parametrize("use_fwd", [True, False])
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("attention_bias_type", [None, "4d"])
@pytest.mark.parametrize("input_dtype", [jnp.float16, jnp.bfloat16, jnp.float32])
def test_fwd_against_ref(
    batch_size: int,
    seq_len: int,
    num_heads: int,
    per_head_dim: int,
    use_fwd: bool,
    causal: bool,
    input_dtype: jnp.dtype,
    attention_bias_type: str,
):
    softmax_scale = 1.0 / (per_head_dim**0.5)
    k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
    q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
    k = jax.random.normal(k2, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
    v = jax.random.normal(k3, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)

    if attention_bias_type == "4d":
        bias = jax.random.normal(k4, (batch_size, num_heads, seq_len, seq_len), dtype=input_dtype)
    else:
        bias = None

    if use_fwd:
        # Also exercise the jit-compiled forward path.
        @jax.jit
        def impl(q, k, v, bias):
            return flash_attention(
                q, k, v, bias, causal=causal, softmax_scale=softmax_scale, dropout_rate=0.0
            )

        o = impl(q, k, v, bias)
    else:
        o = flash_attention(
            q,
            k,
            v,
            bias,
            causal=causal,
            softmax_scale=softmax_scale,
            dropout_rate=0.0,
        )
    o_ref = mha_reference(
        q,
        k,
        v,
        bias,
        causal=causal,
        softmax_scale=softmax_scale,
        dropout_rate=0.0,
    )
    if input_dtype == jnp.float16:
        chex.assert_trees_all_close(o, o_ref, atol=0.07)
    elif input_dtype == jnp.float32:
        chex.assert_trees_all_close(o, o_ref, atol=0.03)
    # bfloat16 runs are compile/shape smoke tests only; no numeric comparison is made.

@pytest.mark.parametrize(
    "batch_size,num_heads,seq_len,per_head_dim",
    [
        (1, 1, 2048, 64),
        (2, 2, 2048, 64),
        (1, 1, 2048, 128),
        (2, 2, 2048, 128),
        (1, 8, 2048, 128),
        (2, 8, 2048, 128),
    ],
)
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("input_dtype", [jnp.bfloat16, jnp.float16, jnp.float32])
@pytest.mark.parametrize("attention_bias_type", [None, "4d"])
def test_bwd_against_ref(
    batch_size: int,
    num_heads: int,
    seq_len: int,
    per_head_dim: int,
    causal: bool,
    input_dtype: jnp.dtype,
    attention_bias_type: str,
):
    softmax_scale = 1.0 / (per_head_dim**0.5)
    k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
    # k1..k4 are already PRNG keys; pass them to jax.random.normal directly.
    q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
    k = jax.random.normal(k2, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
    v = jax.random.normal(k3, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)

    if attention_bias_type == "4d":
        bias = jax.random.normal(k4, (batch_size, num_heads, seq_len, seq_len), dtype=input_dtype)
    else:
        bias = None
    segment_ids = None

    def fn(q, k, v, bias):
        return flash_attention(
            q,
            k,
            v,
            bias,
            causal=causal,
            softmax_scale=softmax_scale,
            dropout_rate=0.0,
        ).sum()

    def ref_fn(q, k, v, bias, segment_ids):
        return mha_reference(
            q,
            k,
            v,
            bias,
            segment_ids,
            causal=causal,
            softmax_scale=softmax_scale,
            dropout_rate=0.0,
        ).sum()

    jax_grads = jax.grad(fn, argnums=(0, 1, 2))(q, k, v, bias)
    jax_ref_grads = jax.grad(ref_fn, argnums=(0, 1, 2))(q, k, v, bias, segment_ids)
    chex.assert_trees_all_close(jax_grads, jax_ref_grads, atol=0.07)
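As a rough pointer (assuming a trn1/trn2 instance with the Neuron SDK, the jax-neuronx plugin, and pytest installed), the tests above can be run with:

pytest axlearn/common/flash_attention/neuron_attention_test.py

On any other backend, the module-level pytestmark skips the entire file.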