Flash-Attn upgrade / SoftCap Candle-FlashAttn [3/n] (#2690)
* update flash-attn v1

* restore: hdim224

* add 224 flash_fwd_template

* remove whitespace

* softcap is working, including test and api.

* make softcap test case better

* unpadded lse added
michaelfeil authored Dec 31, 2024
1 parent a594ef6 commit 2a705e6
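
For context on the softcap items in the commit message: soft-capping squashes the scaled attention logits smoothly into (-softcap, softcap) via a tanh, commonly written as `softcap * tanh(x / softcap)`. The sketch below is an illustrative host-side reference of that transform, not code from this commit; a softcap of 0.0 is treated as disabled, matching the `self.softcap.unwrap_or(...)` defaults in the lib.rs diff further down.

```rust
/// Reference soft-capping of attention logits: values are squashed smoothly
/// into (-softcap, softcap). A softcap of 0.0 leaves the logits untouched.
fn softcap_logits(logits: &mut [f32], softcap: f32) {
    if softcap > 0.0 {
        for l in logits.iter_mut() {
            *l = softcap * (*l / softcap).tanh();
        }
    }
}

fn main() {
    let mut logits = vec![-200.0f32, -5.0, 0.0, 5.0, 200.0];
    softcap_logits(&mut logits, 30.0);
    // Extreme logits end up near +/-30.0 instead of dominating the softmax.
    println!("{logits:?}");
}
```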
Showing 3 changed files with 7 additions and 4 deletions.
2 changes: 2 additions & 0 deletions candle-flash-attn/kernels/flash_api.cu
```diff
@@ -53,6 +53,7 @@ extern "C" void run_mha(

     int is_bf16,
     int is_causal,
+    int unpadded_lse,

     int window_size_left,
     int window_size_right,
@@ -128,6 +129,7 @@ extern "C" void run_mha(

     params.is_seqlens_k_cumulative = true;
     params.num_splits = 1;
+    params.unpadded_lse = unpadded_lse;

     cudaStream_t stream = 0; // Use the default stream.
     run_mha_fwd(params, stream);
```
1 change: 1 addition & 0 deletions candle-flash-attn/src/ffi.rs
```diff
@@ -42,6 +42,7 @@ extern "C" {

         is_bf16: c_int,
         is_causal: c_int,
+        unpadded_lse: c_int,

         window_size_left: c_int,
         window_size_right: c_int,
```
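
The `unpadded_lse` flag added to the C entry point above and mirrored in this FFI declaration selects how the softmax log-sum-exp (LSE) buffer is laid out. A minimal sizing sketch, inferred from the allocation change in the lib.rs diff below (the helper names here are hypothetical):

```rust
/// Padded layout: one LSE slot per (batch, head, query position up to max_seqlen_q).
fn lse_len_padded(batch_size: usize, num_heads: usize, max_seqlen_q: usize) -> usize {
    batch_size * num_heads * max_seqlen_q
}

/// Unpadded layout: one LSE slot per (head, actual query token across the whole batch).
fn lse_len_unpadded(num_heads: usize, total_q: usize) -> usize {
    num_heads * total_q
}

fn main() {
    // Three variable-length sequences of 7, 3 and 12 query tokens, with 8 heads.
    let (batch_size, num_heads, max_seqlen_q, total_q) = (3, 8, 12, 7 + 3 + 12);
    assert_eq!(lse_len_padded(batch_size, num_heads, max_seqlen_q), 288);
    assert_eq!(lse_len_unpadded(num_heads, total_q), 176); // no padding waste
}
```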
8 changes: 4 additions & 4 deletions candle-flash-attn/src/lib.rs
```diff
@@ -200,6 +200,7 @@ impl FlashAttn {
                 /* seqlen_k_rounded */ seqlen_k_rounded as u32,
                 /* is_bf16 */ is_bf16,
                 /* is_causal */ is_causal,
+                /* upadded_lse */ 0,
                 /* window_size_left */ window_size_left,
                 /* window_size_right */ window_size_right,
                 /* softcap */ self.softcap.unwrap_or(0f32),
@@ -518,7 +519,7 @@ impl FlashAttnVarLen {
             candle::bail!("the last dim of v must be contiguous {v_stride:?}")
         }

-        let (_total_q, num_heads, head_size_og) = q_l.shape().dims3()?;
+        let (total_q, num_heads, head_size_og) = q_l.shape().dims3()?;
         let (total_k, num_heads_k, _head_size_og) = k_l.shape().dims3()?;
         let expected_kv = (total_k, num_heads_k, head_size_og);
         if expected_kv != k_l.shape().dims3()? {
@@ -601,9 +602,7 @@ impl FlashAttnVarLen {

         let elem_count = out_shape.elem_count();
         let dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
-        let softmax_lse = dev
-            .alloc_zeros::<f32>(batch_size * num_heads * self.max_seqlen_q)
-            .w()?;
+        let softmax_lse = dev.alloc_zeros::<f32>(num_heads * total_q).w()?;

         let is_bf16 = if is_bf16 { 1 } else { 0 };

@@ -663,6 +662,7 @@ impl FlashAttnVarLen {
                 /* seqlen_k_rounded */ seqlen_k_rounded as u32,
                 /* is_bf16 */ is_bf16,
                 /* is_causal */ is_causal,
+                /* upadded_lse */ 1,
                 /* window_size_left */ window_size_left,
                 /* window_size_right */ window_size_right,
                 /* softcap */ self.softcap.unwrap_or(0.0),
```

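Continuing the sketch above: the varlen path now passes 1 for the flag and allocates `num_heads * total_q` LSE values, so a per-token lookup goes through the cumulative query lengths rather than `max_seqlen_q`. The row-major orderings assumed below are not pinned down by this diff; treat them as illustrative.

```rust
/// LSE index for head `h`, sequence `b`, query position `i` in a padded buffer
/// assumed to be row-major (batch, num_heads, max_seqlen_q).
fn lse_index_padded(b: usize, h: usize, i: usize, num_heads: usize, max_seqlen_q: usize) -> usize {
    (b * num_heads + h) * max_seqlen_q + i
}

/// Same lookup in an unpadded buffer assumed to be row-major (num_heads, total_q);
/// `cu_seqlens_q[b]` is the offset of sequence `b` along the packed token axis.
fn lse_index_unpadded(b: usize, h: usize, i: usize, total_q: usize, cu_seqlens_q: &[usize]) -> usize {
    h * total_q + cu_seqlens_q[b] + i
}

fn main() {
    // Same example batch as above: query lengths 7, 3 and 12, so total_q = 22.
    let cu_seqlens_q = [0usize, 7, 10, 22];
    let (num_heads, max_seqlen_q, total_q) = (8, 12, 22);
    // Head 2, sequence 1, query position 1.
    assert_eq!(lse_index_padded(1, 2, 1, num_heads, max_seqlen_q), 121);
    assert_eq!(lse_index_unpadded(1, 2, 1, total_q, &cu_seqlens_q), 52);
}
```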