From 5d509773fa914da5ebb21aaf0ef96fedcbb0fc10 Mon Sep 17 00:00:00 2001
From: Hanzhi Zhou
Date: Wed, 8 Jan 2025 15:29:40 -0800
Subject: [PATCH] Enable cudnn dropout

---
 axlearn/common/flash_attention/utils.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/axlearn/common/flash_attention/utils.py b/axlearn/common/flash_attention/utils.py
index e9b92e06..0cc30657 100644
--- a/axlearn/common/flash_attention/utils.py
+++ b/axlearn/common/flash_attention/utils.py
@@ -224,7 +224,6 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
             or explicit_bias.has_value()
             or jnp.float32 in (query.dtype, key.dtype, value.dtype)
             or query.shape[1] != key.shape[1]
-            or dropout_rate != 0.0
         ):
             logging.warning("Flash attention falling back to Triton GPU kernel.")
             return gpu_flash_attention(
@@ -248,7 +247,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
             bias=explicit_bias.value(),
             softmax_scale=softmax_scale,
             causal=causal.has_value(),
-            dropout_rate=0.0,
+            dropout_rate=dropout_rate,
         )
 
     elif backend == "tpu":
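
Note for reviewers: a minimal, self-contained sketch of the dispatch predicate this patch relaxes. The helper name `needs_triton_fallback` and its boolean argument are illustrative only, not axlearn's API; the point is that a nonzero dropout rate no longer forces the Triton fallback and is instead forwarded to the cuDNN kernel call as `dropout_rate=dropout_rate` rather than being hard-coded to 0.0.

    import jax.numpy as jnp

    def needs_triton_fallback(query, key, value, *, has_explicit_bias: bool) -> bool:
        # Illustrative post-patch condition: dropout no longer appears here.
        return (
            has_explicit_bias
            or jnp.float32 in (query.dtype, key.dtype, value.dtype)
            or query.shape[1] != key.shape[1]
            # Pre-patch, `or dropout_rate != 0.0` was the final clause, so any
            # nonzero dropout rate fell back to the Triton GPU kernel.
        )

    # bf16 inputs, equal query/key lengths, no explicit bias: stays on the
    # cuDNN path even when training with dropout enabled.
    q = k = v = jnp.zeros((2, 128, 8, 64), dtype=jnp.bfloat16)
    assert not needs_triton_fallback(q, k, v, has_explicit_bias=False)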