Flash Attention for Neuron
apoorvtintin committed Jan 23, 2025
1 parent 185b1b5 commit f40c4cc
Showing 3 changed files with 333 additions and 2 deletions.
168 changes: 168 additions & 0 deletions axlearn/common/flash_attention/neuron_attention.py
@@ -0,0 +1,168 @@
# Copyright © 2024 Amazon Inc.
"""Flash attention Kernels using NKI on Neuron. Tested on trn1 & trn2."""
from functools import partial

import jax
import jax.numpy as jnp

# Import needed to enable JAX cache on Neuron
import jax_neuronx # pylint: disable=unused-import
import neuronxcc.nki.language as nl
from jax import custom_vjp
from neuronxcc.nki.kernels.attention import flash_attn_bwd, flash_fwd

Tensor = jax.Array

lnc = 2 if jax.devices()[0].device_kind == "NC_v3d" else 1


@partial(custom_vjp, nondiff_argnums=(4, 5, 6))
def flash_attention(
query: Tensor,
key: Tensor,
value: Tensor,
bias: Tensor,
causal: bool = False,
softmax_scale: float = 1.0,
dropout_rate: float = 0.0,
):
"""Wraps _mha_forward for custom vjp.
Args:
query: Query of shape [batch_size, target_length, num_heads, per_head_dim].
key: Key of shape [batch_size, source_length, num_heads, per_head_dim].
value: Value of shape [batch_size, source_length, num_heads, per_head_dim].
bias: Optional logit biases of shape [batch_size, num_heads, target_length, source_length].
        causal: Whether to apply causal mask.
        softmax_scale: Optional scale to apply to softmax. Defaults to 1.
        dropout_rate: Dropout rate. Defaults to 0.0 (no dropout).
Returns:
The attention outputs of shape [batch_size, target_length, num_heads, per_head_dim].
"""
out, _ = _mha_forward(query, key, value, bias, causal, softmax_scale, dropout_rate)
return out


def _mha_forward(query, key, value, bias, causal, softmax_scale, dropout_rate):
"""Computes attention outputs following FlashAttention.
See also `_mha_backward` for the backward pass.
Args:
query: Input query.
key: Input key.
value: Input value.
bias: Input bias.
        causal: Whether to apply causal mask.
softmax_scale: Softmax scale to use in the kernel.
dropout_rate: Dropout rate to use in the kernel.
"""
# Get the batch size, sequence lengths, number of heads, and hidden dimension.
batch_size, _, num_heads, _ = query.shape

# Transpose the query, key, and value tensors.
q = query.transpose(0, 2, 3, 1) # [batch_size, num_heads, d_model, q_seq_len].
k = key.transpose(0, 2, 3, 1) # [batch_size, num_heads, d_model, kv_seq_len].
v = value.transpose(0, 2, 1, 3) # [batch_size, num_heads, kv_seq_len, d_model].

seed = jnp.array([1])

# Call the NKI kernel, duplicate the kernel if we cannot shard on num_heads.
if (num_heads % 2) == 0 and (num_heads // 2 > 0):
grid = batch_size, nl.nc(lnc) * (num_heads // lnc)
else:
grid = batch_size, num_heads

if bias is not None:
assert (
bias.ndim == 4
), f"Neuron flash_attention is only expecting bias.ndim = 4 but got {bias.ndim}"
attn_output, lse = flash_fwd[grid](
q,
k,
v,
seed,
bias,
use_causal_mask=causal,
softmax_scale=softmax_scale,
mixed_precision=True,
dropout_p=dropout_rate,
)
else:
attn_output, lse = flash_fwd[grid](
q,
k,
v,
seed,
use_causal_mask=causal,
softmax_scale=softmax_scale,
mixed_precision=True,
dropout_p=dropout_rate,
)
# Transpose the output back to the original shape.
attn_output = attn_output.transpose(0, 2, 1, 3) # [batch_size, q_seq_len, num_heads, d_model].

return attn_output, (lse, attn_output, q, k, v, bias)


def _mha_backward(causal, softmax_scale, dropout_rate, res, d_attn_output):
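    """Computes gradients of `flash_attention` following FlashAttention.
    See also `_mha_forward` for the forward pass.
    Args:
        causal: Whether the forward pass applied a causal mask.
        softmax_scale: Softmax scale used in the kernel.
        dropout_rate: Dropout rate used in the kernel.
        res: Residuals saved by `_mha_forward`.
        d_attn_output: Gradient of the loss w.r.t. the attention output.
    Returns:
        Gradients w.r.t. query, key, value, and None for the bias.
    """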
lse, o, q, k, v, bias = res
batch_size, num_heads, _, _ = q.shape

# Transpose the input tensors.
o = o.transpose(0, 2, 3, 1)
dy = d_attn_output.transpose(0, 2, 3, 1)

# Transpose v tensor.
v = jnp.transpose(v, axes=(0, 1, 3, 2))
seed = jnp.array([1])

# Call the NKI kernel, duplicate the kernel if we cannot shard on num_heads.
if (num_heads % 2) == 0 and (num_heads // 2 > 0):
grid = batch_size, nl.nc(lnc) * (num_heads // lnc)
else:
grid = batch_size, num_heads

if bias is not None:
assert (
bias.ndim == 4
), f"Neuron flash_attention is only expecting bias.ndim = 4 but got {bias.ndim}"
d_query, d_key, d_value = flash_attn_bwd[grid](
q,
k,
v,
o,
dy,
lse,
seed,
bias,
use_causal_mask=causal,
mixed_precision=True,
dropout_p=dropout_rate,
softmax_scale=softmax_scale,
)
else:
d_query, d_key, d_value = flash_attn_bwd[grid](
q,
k,
v,
o,
dy,
lse,
seed,
use_causal_mask=causal,
mixed_precision=True,
dropout_p=dropout_rate,
softmax_scale=softmax_scale,
)

# Transpose the gradients back to the original shape.
d_query = d_query.transpose(0, 3, 1, 2)
d_key = d_key.transpose(0, 3, 1, 2)
d_value = d_value.transpose(0, 3, 1, 2)

return d_query, d_key, d_value, None


flash_attention.defvjp(_mha_forward, _mha_backward)
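
A minimal usage sketch of the wrapper above, with illustrative shapes; it assumes a trn1/trn2 Neuron device is attached, since the NKI kernels only run on Neuron hardware:

import jax
import jax.numpy as jnp

from axlearn.common.flash_attention.neuron_attention import flash_attention

batch_size, seq_len, num_heads, per_head_dim = 1, 2048, 8, 128
softmax_scale = per_head_dim**-0.5
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=jnp.bfloat16)
k = jax.random.normal(k2, (batch_size, seq_len, num_heads, per_head_dim), dtype=jnp.bfloat16)
v = jax.random.normal(k3, (batch_size, seq_len, num_heads, per_head_dim), dtype=jnp.bfloat16)

# Forward pass with a causal mask and no explicit bias.
out = flash_attention(q, k, v, None, causal=True, softmax_scale=softmax_scale, dropout_rate=0.0)

# Backward pass flows through the custom VJP; only q, k, v receive gradients here.
def loss_fn(q, k, v):
    return flash_attention(q, k, v, None, causal=True, softmax_scale=softmax_scale, dropout_rate=0.0).sum()

dq, dk, dv = jax.grad(loss_fn, argnums=(0, 1, 2))(q, k, v)
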
141 changes: 141 additions & 0 deletions axlearn/common/flash_attention/neuron_attention_test.py
@@ -0,0 +1,141 @@
# Copyright © 2024 Amazon Inc.
"""Tests for Flash attention on Neuron. Tested on trn1 & trn2."""

import chex
import jax
import jax.numpy as jnp
import pytest

from axlearn.common.flash_attention.neuron_attention import flash_attention
from axlearn.common.flash_attention.utils import mha_reference

if jax.default_backend() != "neuron":
pytestmark = pytest.mark.skip(reason="Incompatible hardware, AWS Neuron only test.")


@pytest.mark.parametrize(
"batch_size,seq_len,num_heads,per_head_dim",
[
(1, 2048, 1, 64),
(2, 2048, 2, 64),
(1, 2048, 1, 128),
(2, 2048, 2, 128),
(1, 2048, 8, 128),
(2, 2048, 8, 128),
],
)
@pytest.mark.parametrize("use_fwd", [True, False])
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("attention_bias_type", [None, "4d"])
@pytest.mark.parametrize("input_dtype", [jnp.float16, jnp.bfloat16, jnp.float32])
def test_fwd_against_ref(
    batch_size: int,
    seq_len: int,
    num_heads: int,
    per_head_dim: int,
    use_fwd: bool,
    causal: bool,
    input_dtype: jnp.dtype,
    attention_bias_type: str,
):
softmax_scale = 1.0 / (per_head_dim**0.5)
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
k = jax.random.normal(k2, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
v = jax.random.normal(k3, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)

if attention_bias_type == "4d":
bias = jax.random.normal(k4, (batch_size, num_heads, seq_len, seq_len), dtype=input_dtype)
else:
bias = None

o = flash_attention(
q,
k,
v,
bias,
causal=causal,
softmax_scale=softmax_scale,
dropout_rate=0.0,
)
o_ref = mha_reference(
q,
k,
v,
bias,
causal=causal,
softmax_scale=softmax_scale,
dropout_rate=0.0,
)
if input_dtype == jnp.float16:
chex.assert_trees_all_close(o, o_ref, atol=0.07)
elif input_dtype == jnp.float32:
chex.assert_trees_all_close(o, o_ref, atol=0.03)


@pytest.mark.parametrize(
"batch_size,num_heads,seq_len,per_head_dim",
[
(1, 1, 2048, 64),
(2, 2, 2048, 64),
(1, 1, 2048, 128),
(2, 2, 2048, 128),
(1, 8, 2048, 128),
(2, 8, 2048, 128),
],
)
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("input_dtype", [jnp.bfloat16, jnp.float16, jnp.float32])
@pytest.mark.parametrize("attention_bias_type", [None, "4d"])
def test_bwd_against_ref(
batch_size: int,
num_heads: int,
seq_len: int,
per_head_dim: int,
causal: bool,
input_dtype: jnp.dtype,
    attention_bias_type: str,
):
softmax_scale = 1.0 / (per_head_dim**0.5)
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
    q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
    k = jax.random.normal(k2, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
    v = jax.random.normal(k3, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)

if attention_bias_type == "4d":
bias = jax.random.normal(k4, (batch_size, num_heads, seq_len, seq_len), dtype=input_dtype)
else:
bias = None
segment_ids = None

def fn(q, k, v, bias):
return flash_attention(
q,
k,
v,
bias,
causal=causal,
softmax_scale=softmax_scale,
dropout_rate=0.0,
).sum()

def ref_fn(q, k, v, bias, segment_ids):
return mha_reference(
q,
k,
v,
bias,
segment_ids,
causal=causal,
softmax_scale=softmax_scale,
dropout_rate=0.0,
).sum()

jax_grads = jax.grad(fn, argnums=(0, 1, 2))(q, k, v, bias)
jax_ref_grads = jax.grad(ref_fn, argnums=(0, 1, 2))(q, k, v, bias, segment_ids)
chex.assert_trees_all_close(jax_grads, jax_ref_grads, atol=0.07)
26 changes: 24 additions & 2 deletions axlearn/common/flash_attention/utils.py
@@ -23,6 +23,9 @@
from axlearn.common.flash_attention.gpu_attention import cudnn_dot_product_attention
from axlearn.common.flash_attention.gpu_attention import flash_attention as gpu_flash_attention
from axlearn.common.flash_attention.gpu_decoding import flash_decoding
from axlearn.common.flash_attention.neuron_attention import (
flash_attention as neuron_flash_attention,
)
from axlearn.common.flash_attention.tpu_attention import tpu_flash_attention
from axlearn.common.layers import dropout
from axlearn.common.utils import Tensor
@@ -106,7 +109,7 @@ def _repeat_kv_heads(num_q_heads: int, key_or_value: Tensor) -> Tensor:


def flash_attention_implementation(
backend: Literal["cpu", "tpu", "gpu", "xla"],
backend: Literal["cpu", "tpu", "gpu", "xla", "neuron"],
*,
softmax_scale: float,
block_size: int = 128,
@@ -276,13 +279,32 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
block_size=block_size,
)

elif backend == "neuron":
key = _repeat_kv_heads(query.shape[2], key)
value = _repeat_kv_heads(query.shape[2], value)

causal, segment_ids, explicit_bias = split(
bias, CausalAttentionBias, SegmentIdAttentionBias
)

if not isinstance(segment_ids, ZeroAttentionBias):
raise ValueError("Sequence Packing is not supported on Neuron backend")
return neuron_flash_attention(
query,
key,
value,
bias=explicit_bias.value(),
causal=causal.has_value(),
softmax_scale=softmax_scale,
dropout_rate=dropout_rate,
)

elif backend in ("cpu", "xla"):
key = _repeat_kv_heads(query.shape[2], key)
value = _repeat_kv_heads(query.shape[2], value)
if backend == "cpu":
logging.warning("Flash attention CPU backend is for testing only.")
logging.warning("Flash attention falling back using plain MHA implementation")

# `causal` is supported.
# `segment_ids` is supported.
causal, segment_ids, explicit_bias = split(

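The Neuron branch above reuses `_repeat_kv_heads` so that grouped-query keys and values are tiled to match the query head count before entering the kernel. As an illustrative sketch only (the actual helper is defined earlier in utils.py and may differ in detail), the repeat amounts to `jnp.repeat` along the head axis:

import jax.numpy as jnp

def repeat_kv_heads_sketch(num_q_heads: int, key_or_value: jnp.ndarray) -> jnp.ndarray:
    """Illustrative only: tile KV heads so their count matches num_q_heads."""
    num_kv_heads = key_or_value.shape[2]  # [batch, seq_len, num_kv_heads, per_head_dim].
    if num_q_heads % num_kv_heads != 0:
        raise ValueError(f"num_q_heads={num_q_heads} must be a multiple of num_kv_heads={num_kv_heads}")
    return jnp.repeat(key_or_value, num_q_heads // num_kv_heads, axis=2)

# Example: 2 KV heads tiled to match 8 query heads.
kv = jnp.ones((2, 2048, 2, 128))
assert repeat_kv_heads_sketch(8, kv).shape == (2, 2048, 8, 128)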