diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index cfeb9b86b9..0b0f6dfe1d 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -68,13 +68,13 @@ void fused_attn_arbitrary_seqlen_fwd_impl( (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); + bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); if (is_bottom_right && s_q == s_kv && !is_padding) { is_causal = true; is_bottom_right = false; } - bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) || - (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); bool is_dropout = (is_training && dropout_probability != 0.0f); bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); const auto cudnn_runtime_version = cudnnGetVersion(); @@ -476,13 +476,13 @@ void fused_attn_arbitrary_seqlen_bwd_impl( (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); + bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); if (is_bottom_right && s_q == s_kv && !is_padding) { is_causal = true; is_bottom_right = false; } - bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) || - (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); bool is_dropout = (dropout_probability != 0.0f); bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); const auto cudnn_runtime_version = cudnnGetVersion();