
WIP: add support for SWA (left,0) + THD/BSHD/SBHD + padding + CM/BRCM
Signed-off-by: Charlene Yang <[email protected]>
cyanguwa committed Dec 12, 2024
1 parent e4c99b0 commit 165f99c
Showing 4 changed files with 71 additions and 47 deletions.
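For context on what this WIP enables: sliding-window attention (SWA) with a window of (left, 0) lets each query attend to at most `left` preceding tokens on top of a causal (CM) or bottom-right causal (BRCM) mask, and the changes below extend that to padded batches and the THD/BSHD/SBHD QKV formats. A minimal usage sketch follows; it assumes the `transformer_engine.pytorch.DotProductAttention` constructor/forward arguments (`num_gqa_groups`, `attn_mask_type`, `window_size`, `qkv_format`, `cu_seqlens_*`) behave as in recent TE releases, and all shapes, lengths, and the window value are illustrative, not taken from this commit.

```python
# Illustrative sketch only -- not code from this commit.
import torch
from transformer_engine.pytorch import DotProductAttention

b, sq, skv, h, hg, d = 2, 2048, 2048, 24, 4, 128  # batch, seqlens, heads, GQA groups, head dim

dpa = DotProductAttention(
    num_attention_heads=h,
    kv_channels=d,
    num_gqa_groups=hg,
    attn_mask_type="padding_causal_bottom_right",  # BRCM + padding
    window_size=(4, 0),                            # SWA: look back 4 tokens, none ahead
    qkv_format="bshd",
)

q = torch.randn(b, sq, h, d, dtype=torch.bfloat16, device="cuda")
k = torch.randn(b, skv, hg, d, dtype=torch.bfloat16, device="cuda")
v = torch.randn(b, skv, hg, d, dtype=torch.bfloat16, device="cuda")

# Cumulative actual (unpadded) sequence lengths; made-up values for two sequences.
cu_seqlens_q = torch.tensor([0, 1500, 3500], dtype=torch.int32, device="cuda")
cu_seqlens_kv = torch.tensor([0, 1500, 3500], dtype=torch.int32, device="cuda")

out = dpa(q, k, v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv)
```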
82 changes: 46 additions & 36 deletions tests/pytorch/fused_attn/test_fused_attn.py
@@ -190,13 +190,13 @@ def test():
return available_backends, fused_attention_backend

backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
with logging_context():
for i in range(3):
os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
_attention_backends["backend_selection_requires_update"] = True
available_backends, fused_attention_backend = test()
if fused_attention_backend == FusedAttnBackend[backends[i]]:
fused_attn_backends.append(fused_attention_backend)
#with logging_context():
for i in range(3):
os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
_attention_backends["backend_selection_requires_update"] = True
available_backends, fused_attention_backend = test()
if fused_attention_backend == FusedAttnBackend[backends[i]]:
fused_attn_backends.append(fused_attention_backend)
return available_backends, fused_attn_backends


@@ -258,6 +258,8 @@ def test_dot_product_attention(
pad_between_seqs=pad_between_seqs,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
unfused_attn_supported = False
print(flash_attn_supported, fused_attn_supported, unfused_attn_supported)
# FlashAttention does not support pad_between_seqs, but _run_dot_product_attention
# manually pads and unpads the input and output of FlashAttention for testing purposes
if pad_between_seqs and not (
Expand Down Expand Up @@ -531,18 +533,22 @@ def test_dpa_bias_shapes(dtype, model_configs, model):

model_configs_swa = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"swa_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
"swa_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"),
"swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
"swa_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
"swa_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"),
"swa_2_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"swa_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
"swa_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
#"swa_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
#"swa_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"),
#"swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
#"swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
#"swa_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
#"swa_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"),
#"swa_2_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
#"swa_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
#"swa_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal_bottom_right", "no_bias"),
#"swa_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal_bottom_right", "no_bias"),
#"swa_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
#"swa_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
"swa_4_0": ModelConfig(4, 24, 4, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"swa_4_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
"swa_4_2": ModelConfig(4, 24, 4, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"),
"swa_4_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"),
}


@@ -552,7 +558,7 @@ def test_dpa_bias_shapes(dtype, model_configs, model):
@pytest.mark.parametrize("model", model_configs_swa.keys())
def test_dpa_sliding_window(dtype, model_configs, model):
"""Test DotProductAttention module with sliding window attention"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, True, False)
test_dot_product_attention(dtype, model_configs, model, False, True, 'bshd_bshd_bshd', True, False)


model_configs_alibi_slopes = {
@@ -619,18 +625,22 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"]
model_configs_layout_thd = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"layout_0_1": ModelConfig(3, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"),
"layout_0_2": ModelConfig(8, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"),
"layout_0_3": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
"layout_0_4": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
"layout_1_1": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"layout_1_2": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"layout_1_3": ModelConfig(1, 16, 1, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"layout_1_4": ModelConfig(8, 16, 1, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"layout_2_1": ModelConfig(1, 16, 16, 128, 128, 128, 0.0, "padding", "no_bias"),
"layout_2_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
"layout_2_3": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"layout_2_4": ModelConfig(8, 16, 16, 64, 2048, 4096, 0.0, "padding_causal", "no_bias"),
#"layout_0_1": ModelConfig(3, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"),
#"layout_0_2": ModelConfig(8, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"),
#"layout_0_3": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
#"layout_0_4": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
#"layout_1_1": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
#"layout_1_2": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
#"layout_1_3": ModelConfig(1, 16, 1, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
#"layout_1_4": ModelConfig(8, 16, 1, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
#"layout_2_1": ModelConfig(1, 16, 16, 128, 128, 128, 0.0, "padding", "no_bias"),
#"layout_2_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
#"layout_2_3": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
#"layout_2_4": ModelConfig(8, 16, 16, 64, 2048, 4096, 0.0, "padding_causal", "no_bias"),
"layout_3_0": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4,0)),
"layout_3_1": ModelConfig(4, 16, 1, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4,0)),
"layout_3_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4,0)),
"layout_3_3": ModelConfig(4, 16, 1, 64, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4,0)),
}
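The `window_size=(4, 0)` entries above combine a small sliding window with the padding + causal / bottom-right-causal masks. As a reference for the mask semantics being exercised (a sketch of the commonly used SWA convention, not the code run by the fused backend): with window (left, right) and right = 0, query i may attend key j only if j <= i (causal) and i - j <= left; for the bottom-right variant the diagonal is shifted by skv - sq so it ends in the bottom-right corner. The helper name below is hypothetical.

```python
# Reference-only sketch of the mask semantics; the fused kernels do not
# materialize an explicit mask like this.
import torch

def swa_causal_mask(sq: int, skv: int, left: int, bottom_right: bool) -> torch.Tensor:
    """Return a [sq, skv] bool mask where True = masked out."""
    i = torch.arange(sq).unsqueeze(1)          # query positions, [sq, 1]
    j = torch.arange(skv).unsqueeze(0)         # key positions,   [1, skv]
    shift = (skv - sq) if bottom_right else 0  # bottom-right aligns the diagonal to the end
    future = j > i + shift                     # causal part: no attending ahead
    too_far = (i + shift - j) > left if left >= 0 else torch.zeros_like(future)
    return future | too_far                    # outside-the-window part

# Example: 8 queries, 10 keys, window (4, 0), bottom-right causal
print(swa_causal_mask(8, 10, 4, bottom_right=True).int())
```

For the padding variants, the same predicate would be applied per sequence using each sequence's actual lengths rather than sq/skv.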


@@ -647,10 +657,10 @@ def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout):
config = model_configs[model]
if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
pytest.skip("qkv_layout not applicable for MQA/GQA")
pad_between_seqs = True
test_dot_product_attention(
dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
)
#pad_between_seqs = True
#test_dot_product_attention(
# dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
#)
if get_cudnn_version() >= (9, 3, 0):
# cuDNN 9.3.0+ is required to run pad_between_seqs = False/True in the same run
pad_between_seqs = False
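On `pad_between_seqs` in the THD tests: when the sequences packed into a THD tensor are separated by padding tokens, the offsets of the real tokens no longer match the offsets that include the padding, and (per the comment above) cuDNN 9.3.0+ is needed to mix both cases in one run. A small illustration of the two offset tensors, using made-up lengths and the `cu_seqlens` / `cu_seqlens_padded` naming convention as I understand it from TE (the test file builds these differently):

```python
# Hypothetical batch of 3 sequences packed in THD format.
import torch

seqlens = [5, 3, 7]      # real tokens per sequence (made up)
pad_between = [2, 1, 0]  # padding tokens inserted after each sequence (made up)

cu_seqlens, cu_seqlens_padded = [0], [0]
for n, p in zip(seqlens, pad_between):
    cu_seqlens.append(cu_seqlens[-1] + n)                    # offsets of real tokens only
    cu_seqlens_padded.append(cu_seqlens_padded[-1] + n + p)  # offsets including the gaps

cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32)                # [0, 5, 8, 15]
cu_seqlens_padded = torch.tensor(cu_seqlens_padded, dtype=torch.int32)  # [0, 7, 11, 18]
```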
16 changes: 14 additions & 2 deletions transformer_engine/common/fused_attn/fused_attn.cpp
@@ -175,7 +175,12 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 &&
bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
(qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) &&
max_seqlen_q <= max_seqlen_kv && dropout == 0.0)) &&
max_seqlen_q <= max_seqlen_kv && dropout == 0.0) ||
((cudnn_runtime_version >= 90500) &&
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK &&
max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 &&
bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
dropout == 0.0)) &&
// bias + mask combination
(!(cudnn_runtime_version >= 8906 &&
(attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
@@ -197,7 +202,14 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
max_seqlen_q == max_seqlen_kv)) &&
dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
(qkv_format == NVTE_QKV_Format::NVTE_BSHD ||
qkv_format == NVTE_QKV_Format::NVTE_SBHD))))) &&
qkv_format == NVTE_QKV_Format::NVTE_SBHD)))) ||
(cudnn_runtime_version >= 90500 &&
((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 &&
(attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) &&
dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS))) &&
// check 64-bit ragged offset support
(supported_ragged_offset_size)) {
flag_arb = true;
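The nested boolean above is easier to read as a standalone predicate. Below is a Python paraphrase of just the newly added cuDNN >= 9.5.0 sliding-window clause, intended purely as a reading aid (the function name is made up and the authoritative logic is the C++ above):

```python
# Reading-aid paraphrase of the new (cudnn_runtime_version >= 90500) clause.
CAUSAL_FAMILY = {
    "causal",
    "causal_bottom_right",
    "padding_causal",
    "padding_causal_bottom_right",
}

def swa_clause_90500(window_left: int, window_right: int, mask_type: str,
                     bias_type: str, dropout: float, cudnn_version: int) -> bool:
    return (
        cudnn_version >= 90500
        and (window_left >= 0 or window_left == -1)  # any finite left window, or unlimited
        and window_right == 0
        and mask_type in CAUSAL_FAMILY
        and dropout == 0.0
        and bias_type == "no_bias"
    )
```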
@@ -71,7 +71,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
is_bottom_right = false;
}
bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK));
bool is_dropout = (is_training && dropout_probability != 0.0f);
bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD);
const auto cudnn_runtime_version = cudnnGetVersion();
@@ -451,7 +452,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
is_bottom_right = false;
}
bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK));
bool is_dropout = (dropout_probability != 0.0f);
bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD);
const auto cudnn_runtime_version = cudnnGetVersion();
14 changes: 7 additions & 7 deletions transformer_engine/pytorch/attention.py
@@ -715,13 +715,13 @@ def get_attention_backend(
attn_mask_type,
)
use_fused_attention = False
elif "padding" in attn_mask_type:
logger.debug(
"Disabling FusedAttention as it does not support sliding window attention "
"with attn_mask_type = %s",
attn_mask_type,
)
use_fused_attention = False
# elif "padding" in attn_mask_type:
# logger.debug(
# "Disabling FusedAttention as it does not support sliding window attention "
# "with attn_mask_type = %s",
# attn_mask_type,
# )
# use_fused_attention = False
if use_flash_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
if _use_flash_attn_3:
logger.debug(
