Enable cudnn dropout
hanzhi713 committed Jan 28, 2025
1 parent b125f00 commit 5d50977
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions axlearn/common/flash_attention/utils.py
@@ -224,7 +224,6 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
             or explicit_bias.has_value()
             or jnp.float32 in (query.dtype, key.dtype, value.dtype)
             or query.shape[1] != key.shape[1]
-            or dropout_rate != 0.0
         ):
             logging.warning("Flash attention falling back to Triton GPU kernel.")
             return gpu_flash_attention(
@@ -248,7 +247,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
             bias=explicit_bias.value(),
             softmax_scale=softmax_scale,
             causal=causal.has_value(),
-            dropout_rate=0.0,
+            dropout_rate=dropout_rate,
         )

     elif backend == "tpu":
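The change above removes dropout as a reason to fall back to the Triton kernel and forwards the configured rate to the cuDNN call instead of hard-coding 0.0. Below is a minimal, hypothetical sketch of how the GPU kernel selection treats dropout_rate after this commit; the function and argument names are invented for illustration and are not the repository's API.

# Hypothetical sketch of the GPU dispatch behavior after this commit.
# Names are invented for illustration; the real logic lives in
# axlearn/common/flash_attention/utils.py.

def pick_gpu_flash_kernel(
    uses_float32: bool,
    has_explicit_bias: bool,
    query_key_lengths_match: bool,
    dropout_rate: float,
) -> tuple[str, float]:
    """Returns (kernel_name, dropout_rate forwarded to that kernel)."""
    if (
        has_explicit_bias
        or uses_float32
        or not query_key_lengths_match
        # Before this commit, `or dropout_rate != 0.0` appeared here, so any
        # nonzero dropout forced the Triton fallback.
    ):
        return "triton", dropout_rate
    # After this commit the cuDNN kernel receives the configured dropout rate
    # instead of a hard-coded 0.0.
    return "cudnn", dropout_rate


if __name__ == "__main__":
    # With dropout enabled and no other fallback condition, dispatch now
    # stays on the cuDNN kernel.
    print(pick_gpu_flash_kernel(False, False, True, 0.1))  # ('cudnn', 0.1)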
