Commit

Add some tests
EricLBuehler committed Aug 22, 2024
1 parent d632eb5 commit da095a6
Showing 1 changed file with 95 additions and 3 deletions.
98 changes: 95 additions & 3 deletions candle-flash-attn/tests/flash_attn_tests.rs
@@ -15,12 +15,23 @@ fn to_vec3_round(t: Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
     Ok(t)
 }
 
-fn fa_acausal(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<Tensor> {
+fn fa_acausal(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    softmax_scale: f32,
+    softcap: Option<f32>,
+) -> Result<Tensor> {
     let in_dtype = q.dtype();
     let q = q.to_dtype(DType::F32)?;
     let k = k.to_dtype(DType::F32)?;
     let v = v.to_dtype(DType::F32)?;
-    let att = (q.matmul(&k.t()?)? * softmax_scale as f64)?;
+    let mut att = (q.matmul(&k.t()?)? * softmax_scale as f64)?;
+    if let Some(softcap) = softcap {
+        att = (att / softcap as f64)?;
+        att = att.tanh()?;
+        att = (att * softcap as f64)?;
+    }
     let att = candle_nn::ops::softmax(&att, D::Minus1)?;
     // Convert to contiguous as matmul doesn't support strided vs for now.
     let output = att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?;
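The softcap branch added above applies the usual logit-softcapping transform before the softmax: scores are rescaled as att = softcap * tanh(att / softcap), so they vary smoothly but never exceed ±softcap. A minimal standalone sketch of that transform on plain f32 values (an illustration only, not part of this commit):

fn softcap_logit(x: f32, softcap: f32) -> f32 {
    // Squash x smoothly into [-softcap, softcap]; near zero the transform is close to the identity.
    softcap * (x / softcap).tanh()
}

fn main() {
    assert!((softcap_logit(0.1, 30.0) - 0.1).abs() < 1e-4); // small logits pass through almost unchanged
    assert!(softcap_logit(1_000.0, 30.0) <= 30.0); // large logits are capped at softcap
}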
@@ -37,7 +37,7 @@ fn flash_attn_acausal() -> Result<()> {
     let v = (&q / 50.)?;
     let q = (&q / 30.)?;
 
-    let ys1 = fa_acausal(&q, &k, &v, 0.5)?;
+    let ys1 = fa_acausal(&q, &k, &v, 0.5, None)?;
     let ys1 = ys1.i(0)?.to_dtype(DType::F32)?;
     let ys2 = {
         let q = q.transpose(1, 2)?;
@@ -133,3 +144,84 @@ fn flash_attn_varlen() -> Result<()> {
     );
     Ok(())
 }
+
+#[test]
+fn flash_attn_acausal_softcap() -> Result<()> {
+    let device = Device::new_cuda(0)?;
+    let q = Tensor::arange(0u32, 48, &device)?
+        .to_dtype(DType::F16)?
+        .reshape((1, 3, 2, 8))?;
+    let k = (&q / 40.)?;
+    let v = (&q / 50.)?;
+    let q = (&q / 30.)?;
+
+    let ys1 = fa_acausal(&q, &k, &v, 0.5, Some(30.))?;
+    let ys1 = ys1.i(0)?.to_dtype(DType::F32)?;
+    let ys2 = {
+        let q = q.transpose(1, 2)?;
+        let k = k.transpose(1, 2)?;
+        let v = v.transpose(1, 2)?;
+        candle_flash_attn::flash_attn_softcap(&q, &k, &v, 0.5, Some(30.), false)?.transpose(1, 2)?
+    };
+    let ys2 = ys2.i(0)?.to_dtype(DType::F32)?;
+    let diff = ys1.sub(&ys2)?.abs()?.flatten_all()?.max(0)?;
+
+    assert_eq!(ys1.dims(), &[3, 2, 8]);
+    assert_eq!(ys2.dims(), &[3, 2, 8]);
+    assert!(diff.to_vec0::<f32>()?.abs() < 1e-5);
+    Ok(())
+}
+
+#[test]
+fn flash_attn_varlen_softcap() -> Result<()> {
+    let device = Device::new_cuda(0)?;
+    let q = Tensor::arange(0u32, 48, &device)?
+        .to_dtype(DType::F16)?
+        .reshape((3, 2, 8))?;
+    let k = (&q / 40.)?;
+    let v = (&q / 50.)?;
+    let q = (&q / 30.)?;
+
+    let seqlens_q = Tensor::new(&[0u32, 2u32], &device)?;
+    let seqlens_k = Tensor::new(&[0u32, 2u32], &device)?;
+
+    let ys = {
+        let q = q.transpose(0, 1)?;
+        let k = k.transpose(0, 1)?;
+        let v = v.transpose(0, 1)?;
+        candle_flash_attn::flash_attn_varlen_softcap(
+            &q,
+            &k,
+            &v,
+            &seqlens_q,
+            &seqlens_k,
+            32,
+            32,
+            0.5,
+            Some(30.),
+            false,
+        )?
+        .transpose(0, 1)?
+    };
+    let ys = ys.to_dtype(DType::F32)?;
+
+    assert_eq!(ys.dims(), &[3, 2, 8]);
+    assert_eq!(
+        to_vec3_round(ys, 4)?,
+        &[
+            [
+                [0.0837, 0.1038, 0.1238, 0.1438, 0.1637, 0.1837, 0.2037, 0.2238],
+                [0.0922, 0.1122, 0.1322, 0.1522, 0.1721, 0.1921, 0.2122, 0.2322]
+            ],
+            [
+                [0.4204, 0.4404, 0.4604, 0.4805, 0.5005, 0.5205, 0.5405, 0.5605],
+                [0.428, 0.448, 0.468, 0.488, 0.5078, 0.5278, 0.5479, 0.5679]
+            ],
+            [
+                [0.7549, 0.7749, 0.7949, 0.8149, 0.835, 0.855, 0.875, 0.895],
+                [0.7607, 0.7808, 0.8008, 0.8208, 0.8408, 0.8608, 0.8809, 0.9009]
+            ]
+        ]
+    );
+    Ok(())
+}
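In flash_attn_varlen_softcap above, seqlens_q and seqlens_k hold cumulative sequence offsets into the packed q/k tensors, so a single batch entry of length 2 is encoded as [0, 2]. A small helper such as the following (a hypothetical illustration, not part of this commit) shows how such offsets would be derived from per-sequence lengths:

fn cumulative_seqlens(lens: &[u32]) -> Vec<u32> {
    // Prefix sum starting at 0: entry i marks where sequence i begins in the packed tensor.
    let mut cu = Vec::with_capacity(lens.len() + 1);
    let mut acc = 0u32;
    cu.push(acc);
    for &len in lens {
        acc += len;
        cu.push(acc);
    }
    cu
}

fn main() {
    // One sequence of length 2 reproduces the [0, 2] offsets used in the test above.
    assert_eq!(cumulative_seqlens(&[2]), vec![0, 2]);
    // Two sequences of lengths 2 and 3 would give [0, 2, 5].
    assert_eq!(cumulative_seqlens(&[2, 3]), vec![0, 2, 5]);
}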
