From d2f1549abc9422242489ec9069859771c66cff67 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 6 Jan 2025 17:09:33 -0800 Subject: [PATCH] fix build issue from previous merge Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../fused_attn/fused_attn_f16_arbitrary_seqlen.cu | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index cfeb9b86b9..0b0f6dfe1d 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -68,13 +68,13 @@ void fused_attn_arbitrary_seqlen_fwd_impl( (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); + bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); if (is_bottom_right && s_q == s_kv && !is_padding) { is_causal = true; is_bottom_right = false; } - bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) || - (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); bool is_dropout = (is_training && dropout_probability != 0.0f); bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); const auto cudnn_runtime_version = cudnnGetVersion(); @@ -476,13 +476,13 @@ void fused_attn_arbitrary_seqlen_bwd_impl( (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); + bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); if (is_bottom_right && s_q == s_kv && !is_padding) { is_causal = true; is_bottom_right = false; } - bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) || - (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); bool is_dropout = (dropout_probability != 0.0f); bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); const auto cudnn_runtime_version = cudnnGetVersion();