Skip to content

Commit

Permalink
Flash-Attn upgrade / SoftCap Candle-FlashAttn [2/n] (#2689)
Browse files Browse the repository at this point in the history
* update flash-attn v1

* restore: hdim224

* add 224 flash_fwd_template

* remove whitespace

* softcap is working, including test and api.

* make softcap test case better

---------

Co-authored-by: laurent <[email protected]>
  • Loading branch information
michaelfeil and LaurentMazare authored Dec 31, 2024
1 parent 71cd6d5 commit a594ef6
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 3 deletions.
16 changes: 13 additions & 3 deletions candle-flash-attn/kernels/flash_api.cu
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ extern "C" void run_mha(
int is_causal,

int window_size_left,
int window_size_right
int window_size_right,

float softcap
) {
Flash_fwd_params params;
// Reset the parameters
Expand Down Expand Up @@ -99,8 +101,16 @@ extern "C" void run_mha(
params.d_rounded = d_rounded;

// Set the different scale values.
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
if (softcap > 0.0) {
params.softcap = softmax_scale / softcap;
params.scale_softmax = softcap;
params.scale_softmax_log2 = softcap * M_LOG2E;
} else{
// Remove potential NaN
params.softcap = 0.0;
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
}

params.p_dropout = 1.; // probability to keep
params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
Expand Down
2 changes: 2 additions & 0 deletions candle-flash-attn/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ extern "C" {

window_size_left: c_int,
window_size_right: c_int,

softcap: f32,
);

}
115 changes: 115 additions & 0 deletions candle-flash-attn/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub struct FlashAttn {
pub alibi_slopes: Option<Tensor>,
pub window_size_left: Option<usize>,
pub window_size_right: Option<usize>,
pub softcap: Option<f32>,
}

fn round_multiple(x: usize, m: usize) -> usize {
Expand Down Expand Up @@ -201,6 +202,7 @@ impl FlashAttn {
/* is_causal */ is_causal,
/* window_size_left */ window_size_left,
/* window_size_right */ window_size_right,
/* softcap */ self.softcap.unwrap_or(0f32),
)
}

Expand Down Expand Up @@ -271,6 +273,7 @@ pub fn flash_attn(
alibi_slopes: None,
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}
Expand Down Expand Up @@ -308,6 +311,7 @@ pub fn flash_attn_windowed(
alibi_slopes: None,
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}
Expand Down Expand Up @@ -342,6 +346,7 @@ pub fn flash_attn_alibi(
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}
Expand Down Expand Up @@ -381,6 +386,52 @@ pub fn flash_attn_alibi_windowed(
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}

/// Flash-attention v2 layer.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors `k` and `v` with fewer heads
/// than `q`. The number of heads in `k` and `v` must be divisible by the number of heads in `q`.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `alibi_slopes` - Optional alibi slopes tensor with shape `(num_heads_q)`.
/// * `softmax_scale` - Scaling factor for the softmax operation.
/// * `window_size_left` - Optional limit on left attention to value tokens.
/// * `window_size_right` - Optional limit on right attention to value tokens.
/// * `softcap` - Gemma style softcap the attention logits before the softmax.
///
/// # Causal Mask
///
/// Setting `window_size_left=None` and `window_size_right=Some(0)` applies a causal mask to the result
/// of `Q @ K^T`.
///
/// # Returns
///
/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
pub fn flash_attn_alibi_windowed_softcap(
q: &Tensor,
k: &Tensor,
v: &Tensor,
alibi_slopes: Option<&Tensor>,
softmax_scale: f32,
window_size_left: Option<usize>,
window_size_right: Option<usize>,
softcap: f32,
) -> Result<Tensor> {
let op = FlashAttn {
softmax_scale,
alibi_slopes: alibi_slopes.cloned(),
window_size_left,
window_size_right,
softcap: Some(softcap),
};
q.apply_op3(k, v, op)
}
Expand All @@ -394,6 +445,7 @@ struct FlashAttnVarLen {
pub alibi_slopes: Option<Tensor>,
pub window_size_left: Option<usize>,
pub window_size_right: Option<usize>,
pub softcap: Option<f32>,
}

impl FlashAttnVarLen {
Expand Down Expand Up @@ -613,6 +665,7 @@ impl FlashAttnVarLen {
/* is_causal */ is_causal,
/* window_size_left */ window_size_left,
/* window_size_right */ window_size_right,
/* softcap */ self.softcap.unwrap_or(0.0),
)
}

Expand Down Expand Up @@ -699,6 +752,7 @@ pub fn flash_attn_varlen(
alibi_slopes: None,
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}
Expand Down Expand Up @@ -752,6 +806,7 @@ pub fn flash_attn_varlen_windowed(
alibi_slopes: None,
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}
Expand Down Expand Up @@ -802,6 +857,7 @@ pub fn flash_attn_varlen_alibi(
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}
Expand Down Expand Up @@ -857,6 +913,65 @@ pub fn flash_attn_varlen_alibi_windowed(
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
softcap: None,
};
q.apply_op3(k, v, op)
}

#[allow(clippy::too_many_arguments)]
/// Flash-attention v2 layer with variable-length batching.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `alibi_slopes` - Option, alibi slopes tensor with shape `(num_heads_q)`.
/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch.
/// * `window_size_left` - Option, limit left attention to value tokens.
/// * `window_size_right` - Option, limit right attention to value tokens.
/// * `softcap` - Gemma style softcap the attention logits before the softmax.
///
/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
/// `seqlen_1 + seqlen_2`, etc.
///
/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
///
/// # Causal mask
///
/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
/// of `Q @ K^T`
pub fn flash_attn_varlen_alibi_windowed_softcap(
q: &Tensor,
k: &Tensor,
v: &Tensor,
alibi_slopes: Option<&Tensor>,
seqlens_q: &Tensor,
seqlens_k: &Tensor,
max_seqlen_q: usize,
max_seqlen_k: usize,
softmax_scale: f32,
window_size_left: Option<usize>,
window_size_right: Option<usize>,
softcap: f32,
) -> Result<Tensor> {
let op = FlashAttnVarLen {
softmax_scale,
max_seqlen_q,
max_seqlen_k,
seqlens_q: seqlens_q.clone(),
seqlens_k: seqlens_k.clone(),
alibi_slopes: alibi_slopes.cloned(),
window_size_left,
window_size_right,
softcap: Some(softcap),
};
q.apply_op3(k, v, op)
}
52 changes: 52 additions & 0 deletions candle-flash-attn/tests/flash_attn_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,20 @@ fn fa_acausal(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<
Ok(output)
}

fn fa_acausal_softcap(q: &Tensor, k: &Tensor, v: &Tensor, softcap: f32) -> Result<Tensor> {
let in_dtype = q.dtype();
let q = q.to_dtype(DType::F32)?;
let k = k.to_dtype(DType::F32)?;
let v = v.to_dtype(DType::F32)?;
// let att = (q.matmul(&k.t()?)? * softmax_scale as f64)?;
let att = q.matmul(&k.t()?)?;
let att = (softcap as f64 * ((att / softcap as f64)?.tanh())?)?;
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let output = att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?;
Ok(output)
}

#[test]
fn flash_attn_acausal() -> Result<()> {
let device = Device::new_cuda(0)?;
Expand Down Expand Up @@ -89,6 +103,44 @@ fn flash_attn_acausal() -> Result<()> {
Ok(())
}

#[test]
fn flash_attn_acausal_softcap() -> Result<()> {
let device = Device::new_cuda(0)?;
let q = Tensor::arange(0u32, 3 * 5 * 8, &device)?
.to_dtype(DType::F16)?
.reshape((1, 3, 5, 8))?;
let k = (&q / 40.)?;
let v = (&q / 50.)?;
let q = (&q / 30.)?;
let softcap = 5.0f32;

let ys1 = fa_acausal_softcap(&q, &k, &v, softcap.clone())?;
let ys1 = ys1.i(0)?.to_dtype(DType::F32)?;
let ys2 = {
let q = q.transpose(1, 2)?;
let k = k.transpose(1, 2)?;
let v = v.transpose(1, 2)?;
candle_flash_attn::flash_attn_alibi_windowed_softcap(
&q,
&k,
&v,
None, // alibi_slopes //
1.0, // softmax //
None, // window_size_left //
None, // window_size_right //
softcap.clone(), // softcap //
)?
.transpose(1, 2)?
};
let ys2 = ys2.i(0)?.to_dtype(DType::F32)?;
let diff = ys1.sub(&ys2)?.abs()?.flatten_all()?.max(0)?;

assert_eq!(ys1.dims(), &[3, 5, 8]);
assert_eq!(ys2.dims(), &[3, 5, 8]);
assert!(diff.to_vec0::<f32>()?.abs() < 1e-3);
Ok(())
}

#[test]
fn flash_attn_varlen() -> Result<()> {
let device = Device::new_cuda(0)?;
Expand Down

0 comments on commit a594ef6

Please sign in to comment.