Support broadcasting attn mask in head dim via kernel (#65)
* Support broadcasting attn mask in head dim via kernel

* Remove apply_mask_scale

* Metal support

* Fix metal support

* Fix check

* Format
EricLBuehler authored Jan 14, 2025
1 parent 5bed542 commit 3388a3c
Showing 4 changed files with 85 additions and 21 deletions.
candle-kernels/src/reduce.cu: 11 changes (7 additions, 4 deletions)
@@ -227,7 +227,7 @@ __device__ void softmax(const T * x, T * dst, const int ncols) {
}

template <typename T>
-__device__ void attn_soft_max(const T * x, const T * mask, T * dst, const int ncols, const int nrows_y, const float scale) {
+__device__ void attn_soft_max(const T * x, const T * mask, T * dst, const int ncols, const int nrows_y, const int elem_per_batch, const float scale) {
const int tid = threadIdx.x;
const int rowx = blockIdx.x;
const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
@@ -253,7 +253,9 @@ __device__ void attn_soft_max(const T * x, const T * mask, T * dst, const int nc
}

const int64_t ix = (int64_t)rowx*ncols + col;
-const int64_t iy = (int64_t)rowy*ncols + col;
+
+const int64_t b_idx = elem_per_batch > 0 ? ix / elem_per_batch : 0;
+const int64_t iy = (int64_t)b_idx * (ncols*nrows_y) + rowy*ncols + col;

const float val = float(x[ix]) * scale + (mask ? float(mask[iy]) : 0.0f);

@@ -640,10 +642,11 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
TYPENAME * dst, \
const int ncols, \
const int nrows_y, \
+const int elem_per_batch, \
const float scale \
) { \
-attn_soft_max<TYPENAME>(x, mask, dst, ncols, nrows_y, scale); \
-} \
+attn_soft_max<TYPENAME>(x, mask, dst, ncols, nrows_y, elem_per_batch, scale); \
+} \

#define RMSNORM_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
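
The new elem_per_batch argument is what lets the kernel address a rank-3 mask: when it is 0 the old rank-2 behaviour is kept, otherwise the batch index is recovered from the flat element index and used to select the per-batch mask slice. A minimal Rust sketch of that index arithmetic (illustrative only, not code from this commit; ncols is the key length and nrows_y the query length, matching the kernel's conventions):

fn mask_index(ix: i64, ncols: i64, nrows_y: i64, elem_per_batch: i64) -> i64 {
    let rowx = ix / ncols;     // flattened row of x that `ix` falls in
    let col = ix % ncols;      // column within that row
    let rowy = rowx % nrows_y; // broadcast the mask in the row dimension
    let b_idx = if elem_per_batch > 0 { ix / elem_per_batch } else { 0 };
    b_idx * (ncols * nrows_y) + rowy * ncols + col
}

fn main() {
    // With elem_per_batch == 0 the batch term vanishes and the offset reduces
    // to the previous rowy*ncols + col, so rank-2 masks behave as before.
    assert_eq!(mask_index(5, 4, 2, 0), 5); // rowy(1)*ncols(4) + col(1)
}
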
candle-metal-kernels/src/lib.rs: 10 changes (10 additions, 0 deletions)
@@ -738,6 +738,7 @@ pub fn call_last_attn_softmax(
mask: &Buffer,
mask_offset: usize,
input_shape: &[usize],
+mask_shape: &[usize],
scale: f32,
ty: SdpaDType,
output: &Buffer,
@@ -749,6 +750,14 @@
let ne02 = input_shape[input_shape.len() - 3] as i64;
let ne03 = input_shape[input_shape.len() - 4] as i64;

+let elem_per_batch = if mask_shape.len() == 2 {
+0
+} else {
+let bs = input_shape[0];
+let el: usize = input_shape.iter().product();
+el / bs
+};
+
let mut nth = 32; // SIMD width
let name = if ne00 % 4 == 0 {
while nth < ne00 / 4 && nth * ne01 * ne02 * ne03 < 256 {
@@ -784,6 +793,7 @@
ne00,
ne01,
ne02,
+elem_per_batch as i64,
scale
)
);
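
Here elem_per_batch is derived on the host before dispatch: a rank-2 mask is signalled with 0, and for a rank-3 mask it is the number of score elements per batch entry, which the kernel later divides by to recover the batch index. A worked example of the same computation (the shapes are made up for illustration, not taken from the diff):

fn main() {
    // Mirrors the new logic in call_last_attn_softmax.
    let input_shape = vec![2usize, 8, 64, 64]; // xs: [bs, heads, q_len, k_len]
    let mask_shape = vec![2usize, 64, 64];     // rank-3 mask: one [q_len, k_len] slice per batch entry

    let elem_per_batch = if mask_shape.len() == 2 {
        0
    } else {
        let bs = input_shape[0];
        let el: usize = input_shape.iter().product();
        el / bs // 8 * 64 * 64 score elements per batch entry
    };
    assert_eq!(elem_per_batch, 8 * 64 * 64);
}
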
candle-metal-kernels/src/reduce.metal: 20 changes (14 additions, 6 deletions)
@@ -827,6 +827,7 @@ kernel void attn_soft_max(
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
+constant int64_t & elem_per_batch,
constant float & scale,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
@@ -838,9 +839,12 @@
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

-device const T * psrc0 = (device const T *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
-device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
-device T * pdst = (device T *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+const int64_t src_offset = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+const int64_t b_idx = elem_per_batch > 0 ? src_offset / elem_per_batch : 0;
+const int64_t mask_offset = b_idx * (ne00*ne01) + i01*ne00;
+device const T * psrc0 = (device const T *) src0 + src_offset;
+device const T * pmask = src1 != src0 ? (device const T *) src1 + mask_offset : nullptr;
+device T * pdst = (device T *) dst + src_offset;

float slope = 1.0f;

@@ -916,6 +920,7 @@ kernel void attn_soft_max_4(
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
+constant int64_t & elem_per_batch,
constant float & scale,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
@@ -927,9 +932,12 @@
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

-device const T * psrc4 = (device const T *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
-device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
-device T * pdst4 = (device T *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+const int64_t src_offset = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+const int64_t b_idx = elem_per_batch > 0 ? src_offset / elem_per_batch : 0;
+const int64_t mask_offset = b_idx * (ne00*ne01) + i01*ne00;
+device const T * psrc4 = (device const T *) src0 + src_offset / 4;
+device const T * pmask = src1 != src0 ? (device const T *) src1 + mask_offset / 4 : nullptr;
+device T * pdst4 = (device T *) dst + src_offset / 4;

float slope = 1.0f;

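
Both Metal variants now derive the mask offset from the source offset in the same way as the CUDA kernel; the _4 variant simply divides the offsets by 4 because it loads vectors of four elements. The broadcast semantics the kernels implement can be summarised by a plain CPU reference, a simplified sketch that omits the ALiBi slope handling visible in the Metal code and is not taken from the repository:

// xs is [b, h, q, k] flattened row-major; mask is [q, k] or, when
// mask_has_batch is set, [b, q, k] and is then broadcast over the head dim only.
fn attn_softmax_ref(
    xs: &mut [f32],
    mask: &[f32],
    (b, h, q, k): (usize, usize, usize, usize),
    scale: f32,
    mask_has_batch: bool,
) {
    for bi in 0..b {
        for hi in 0..h {
            for qi in 0..q {
                let row = &mut xs[((bi * h + hi) * q + qi) * k..][..k];
                let mrow = if mask_has_batch {
                    &mask[(bi * q + qi) * k..][..k] // rank-3: per-batch slice
                } else {
                    &mask[qi * k..][..k] // rank-2: shared by every (batch, head)
                };
                // Matches the kernels: val = x * scale + mask, then a softmax
                // over the last dimension.
                for (x, m) in row.iter_mut().zip(mrow) {
                    *x = *x * scale + *m;
                }
                let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
                let mut sum = 0.0f32;
                for x in row.iter_mut() {
                    *x = (*x - max).exp();
                    sum += *x;
                }
                for x in row.iter_mut() {
                    *x /= sum;
                }
            }
        }
    }
}
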
candle-nn/src/ops.rs: 65 changes (54 additions, 11 deletions)
@@ -613,14 +613,17 @@ impl candle::InplaceOp2 for AttnSoftmaxLastDim {
if a_l.dims().len() != 4 {
candle::bail!("attn-softmax-last-dim expects xs of rank 4");
}
-if mask_l.dims().len() != 2 {
-candle::bail!("attn-softmax-last-dim expects mask of rank 2");
+if mask_l.dims().len() != 2 && mask_l.dims().len() != 3 {
+candle::bail!("attn-softmax-last-dim expects mask of rank 2 or 3");
}
if mask_l.dim(D::Minus1)? != a_l.dim(D::Minus1)?
|| mask_l.dim(D::Minus2)? != a_l.dim(D::Minus2)?
{
candle::bail!("attn-softmax-last-dim expects last 2 dims to match xs last 2 dims");
}
+if mask_l.dims().len() == 3 && mask_l.dim(0)? != a_l.dim(0)? {
+candle::bail!("attn-softmax-last-dim expects rank-3 mask bs to match xs bs");
+}

candle_metal_kernels::call_last_attn_softmax(
device.metal_device(),
@@ -631,6 +634,7 @@
mask_s.buffer(),
mask_l.start_offset() * mask_s.dtype().size_in_bytes(),
a_l.dims(),
+mask_l.dims(),
self.scale,
ty,
&a_s.buffer(),
@@ -668,14 +672,17 @@
if a_l.dims().len() != 4 {
candle::bail!("attn-softmax-last-dim expects xs of rank 4");
}
-if mask_l.dims().len() != 2 {
-candle::bail!("attn-softmax-last-dim expects mask of rank 2");
+if mask_l.dims().len() != 2 && mask_l.dims().len() != 3 {
+candle::bail!("attn-softmax-last-dim expects mask of rank 2 or 3");
}
if mask_l.dim(D::Minus1)? != a_l.dim(D::Minus1)?
|| mask_l.dim(D::Minus2)? != a_l.dim(D::Minus2)?
{
candle::bail!("attn-softmax-last-dim expects last 2 dims to match xs last 2 dims");
}
+if mask_l.dims().len() == 3 && mask_l.dim(0)? != a_l.dim(0)? {
+candle::bail!("attn-softmax-last-dim expects rank-3 mask bs to match xs bs");
+}

struct S<'a> {
scale: f32,
@@ -703,6 +710,13 @@ impl candle::InplaceOp2 for AttnSoftmaxLastDim {
let dims = self.a_l.shape().dims();
let dim_m1 = dims[dims.len() - 1];
let nrows_y = dims[dims.len() - 2];
+let elem_per_batch = if mask_l.dims().len() == 2 {
+0
+} else {
+let bs = dims[0];
+el / bs
+};
+
let (nrows_x, ncols_x) = (el / dim_m1, dim_m1);

const WARP_SIZE: usize = 32;
@@ -719,7 +733,15 @@
};
let func =
dev.get_or_load_func(&kernel_name::<T>("attn_soft_max"), kernels::REDUCE)?;
-let params = (&a, &mask, &a, ncols_x as i32, nrows_y as i32, self.scale);
+let params = (
+&a,
+&mask,
+&a,
+ncols_x as i32,
+nrows_y as i32,
+elem_per_batch as i32,
+self.scale,
+);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;

@@ -783,14 +805,17 @@ impl candle::CustomOp2 for AttnSoftmaxLastDim {
if a_l.dims().len() != 4 {
candle::bail!("attn-softmax-last-dim expects xs of rank 4");
}
-if mask_l.dims().len() != 2 {
-candle::bail!("attn-softmax-last-dim expects mask of rank 2");
+if mask_l.dims().len() != 2 && mask_l.dims().len() != 3 {
+candle::bail!("attn-softmax-last-dim expects mask of rank 2 or 3");
}
if mask_l.dim(D::Minus1)? != a_l.dim(D::Minus1)?
|| mask_l.dim(D::Minus2)? != a_l.dim(D::Minus2)?
{
candle::bail!("attn-softmax-last-dim expects last 2 dims to match xs last 2 dims");
}
+if mask_l.dims().len() == 3 && mask_l.dim(0)? != a_l.dim(0)? {
+candle::bail!("attn-softmax-last-dim expects rank-3 mask bs to match xs bs");
+}

let elem_count = a_l.shape().elem_count();
let output = device.new_buffer(elem_count, a_s.dtype(), "attn-softmax")?;
@@ -803,6 +828,7 @@
mask_s.buffer(),
mask_l.start_offset() * mask_s.dtype().size_in_bytes(),
a_l.dims(),
+mask_l.dims(),
self.scale,
ty,
&output,
@@ -840,14 +866,17 @@
if a_l.dims().len() != 4 {
candle::bail!("attn-softmax-last-dim expects xs of rank 4");
}
-if mask_l.dims().len() != 2 {
-candle::bail!("attn-softmax-last-dim expects mask of rank 2");
+if mask_l.dims().len() != 2 && mask_l.dims().len() != 3 {
+candle::bail!("attn-softmax-last-dim expects mask of rank 2 or 3");
}
if mask_l.dim(D::Minus1)? != a_l.dim(D::Minus1)?
|| mask_l.dim(D::Minus2)? != a_l.dim(D::Minus2)?
{
candle::bail!("attn-softmax-last-dim expects last 2 dims to match xs last 2 dims");
}
+if mask_l.dims().len() == 3 && mask_l.dim(0)? != a_l.dim(0)? {
+candle::bail!("attn-softmax-last-dim expects rank-3 mask bs to match xs bs");
+}

struct S {
scale: f32,
Expand All @@ -874,6 +903,12 @@ impl candle::CustomOp2 for AttnSoftmaxLastDim {
let dims = a_l.shape().dims();
let dim_m1 = dims[dims.len() - 1];
let nrows_y = dims[dims.len() - 2];
+let elem_per_batch = if mask_l.dims().len() == 2 {
+0
+} else {
+let bs = dims[0];
+el / bs
+};
let (nrows_x, ncols_x) = (el / dim_m1, dim_m1);

const WARP_SIZE: usize = 32;
Expand All @@ -892,7 +927,15 @@ impl candle::CustomOp2 for AttnSoftmaxLastDim {
dev.get_or_load_func(&kernel_name::<T>("attn_soft_max"), kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
-let params = (&a, &mask, &dst, ncols_x as i32, nrows_y as i32, self.scale);
+let params = (
+&a,
+&mask,
+&dst,
+ncols_x as i32,
+nrows_y as i32,
+elem_per_batch as i32,
+self.scale,
+);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;

Expand All @@ -917,7 +960,7 @@ impl candle::CustomOp2 for AttnSoftmaxLastDim {
/// candle_nn::ops::softmax_last_dim(&(xs.broadcast_add(&mask)? * scale as f64)?)?
/// ```
/// - `xs` must be a rank-4 tensor
-/// - `mask` must be a rank-2 matrix
+/// - `mask` must be a rank-2 or rank-3 tensor
/// - The last 2 dimensions of `xs` must match the dimensions of `mask`.
///
/// Note: if the last dim of `xs` is a multiple of 4, a vectorized implementation will be used.
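
As a usage sketch of the documented semantics with the newly allowed rank-3 mask (shapes and values are illustrative; this assumes the workspace alias of candle-core as candle used throughout these files, and it only shows the unfused reference expression, since the public wrapper around AttnSoftmaxLastDim is not part of this diff):

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (b, h, q, k) = (2usize, 8, 4, 4);
    let xs = Tensor::randn(0f32, 1f32, (b, h, q, k), &dev)?;
    // Rank-3 mask: one [q, k] mask per batch entry, broadcast over heads.
    let mask = Tensor::zeros((b, q, k), DType::F32, &dev)?;
    let scale = 0.125f64;

    // Reference expression from the doc comment above; the unsqueeze turns the
    // rank-3 mask into [b, 1, q, k] so broadcast_add spreads it over the head
    // dim, which is what the fused kernels now do internally.
    let reference =
        candle_nn::ops::softmax_last_dim(&(xs.broadcast_add(&mask.unsqueeze(1)?)? * scale)?)?;
    assert_eq!(reference.dims(), &[b, h, q, k]);
    Ok(())
}
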
