Fix metal sdpa for v stride
EricLBuehler committed Nov 4, 2024
1 parent 2d3df4a commit ec1c76e
Showing 3 changed files with 14 additions and 4 deletions.
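
In short: the Metal SDPA vector kernel advanced the values pointer with K's per-head stride, so it read from the wrong memory whenever V's stride differed from K's. This commit threads V's own stride from candle-nn through call_sdpa_vector and into the kernel. As a hypothetical illustration (not part of the commit), the two strides diverge whenever V is a non-contiguous view, for example a slice of a larger pre-allocated cache:

// Hypothetical example, not from this commit: K is freshly materialized while
// V is a narrowed view of a bigger cache, so their per-head strides differ.
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    // (batch, kv_heads, seq_len, head_dim), contiguous.
    let k = Tensor::zeros((1, 8, 128, 64), DType::F32, &dev)?;
    // V sliced out of a cache allocated for 512 positions; the view keeps the
    // cache's strides.
    let v_cache = Tensor::zeros((1, 8, 512, 64), DType::F32, &dev)?;
    let v = v_cache.narrow(2, 0, 128)?;
    println!("k per-head stride: {}", k.layout().stride()[1]); // 8192
    println!("v per-head stride: {}", v.layout().stride()[1]); // 32768
    Ok(())
}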
candle-metal-kernels/src/lib.rs (10 additions, 3 deletions)
@@ -1879,6 +1879,7 @@ pub fn call_sdpa_vector(
     k_stride: &[usize],
     k_buffer: &Buffer,
     v_offset: usize,
+    v_stride: &[usize],
     v_buffer: &Buffer,
     output: &Buffer,
     alpha: f32,
@@ -1890,7 +1891,8 @@ pub fn call_sdpa_vector(
     let gqa_factor = (q_shape[1] / k_shape[1]) as i32;
     let n = k_shape[2] as i32;
     let b = (q_shape[0] * q_shape[1]) as i32;
-    let stride = k_stride[1];
+    let kstride = k_stride[1];
+    let vstride = v_stride[1];
 
     let name = match (bk, itype) {
         (32, SdpaDType::F16) => "sdpa_vector_float16_t_32",
@@ -1949,15 +1951,20 @@ pub fn call_sdpa_vector(
     encoder.set_bytes(
         6,
         std::mem::size_of::<usize>() as u64,
-        &stride as *const usize as *const c_void,
+        &kstride as *const usize as *const c_void,
     );
     encoder.set_bytes(
         7,
+        std::mem::size_of::<usize>() as u64,
+        &vstride as *const usize as *const c_void,
+    );
+    encoder.set_bytes(
+        8,
         std::mem::size_of::<f32>() as u64,
         &alpha as *const f32 as *const c_void,
     );
     encoder.set_bytes(
-        8,
+        9,
         std::mem::size_of::<f32>() as u64,
         &softcapping as *const f32 as *const c_void,
     );
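
Note on the hunk above: because vstride now occupies buffer index 7, the existing scale (alpha) and softcapping bindings shift to indices 8 and 9. Inferring from this hunk alone (indices 0 through 5 are not shown here and are an assumption), the binding order is presumably queries/keys/values/out in 0-3, gqa_factor in 4, N in 5, k_stride in 6, v_stride in 7, scale in 8, and softcapping in 9, matching the parameter order added to the Metal kernel signature below.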
candle-metal-kernels/src/scaled_dot_product_attention.metal (3 additions, 1 deletion)
@@ -56,6 +56,7 @@ template <typename T, int D>
     const constant int& gqa_factor,
     const constant int& N,
     const constant size_t& k_stride,
+    const constant size_t& v_stride,
     const constant float& scale,
     const constant float& softcapping,
     uint3 tid [[threadgroup_position_in_grid]],
@@ -82,7 +83,7 @@ template <typename T, int D>
   const int kv_head_idx = head_idx / gqa_factor;
   queries += head_idx * D + simd_lid * elem_per_thread;
   keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
-  values += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;
+  values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;
   out += head_idx * D + simd_gid * elem_per_thread;
 
   // Read the query and 0 the output accumulator
@@ -1234,6 +1235,7 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 256, 2, 2);
     const constant int& gqa_factor, \
     const constant int& N, \
     const constant size_t& k_stride, \
+    const constant size_t& v_stride, \
     const constant float& scale, \
     const constant float& softcapping, \
     uint3 tid [[threadgroup_position_in_grid]], \
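
The kernel-side change is a one-word fix to the base-address computation for values. A sketch of the same arithmetic in Rust, for illustration only (the names mirror the Metal code):

// Per-SIMD-group base offset into the V buffer, mirroring the kernel's pointer
// arithmetic. Before this commit the kernel used k_stride here, which lands on
// the wrong head of V whenever V's per-head stride differs from K's.
fn value_base_offset(
    kv_head_idx: usize,
    v_stride: usize,
    simd_gid: usize,
    head_dim: usize, // D in the kernel
    simd_lid: usize,
    elem_per_thread: usize,
) -> usize {
    kv_head_idx * v_stride + simd_gid * head_dim + simd_lid * elem_per_thread
}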
candle-nn/src/ops.rs (1 addition, 0 deletions)
@@ -1094,6 +1094,7 @@ impl candle::CustomOp3 for Sdpa {
                 k_l.stride(),
                 k.buffer(),
                 v_l.start_offset(),
+                v_l.stride(),
                 v.buffer(),
                 &output,
                 self.scale,
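
On the candle-nn side the only change is forwarding v_l.stride() alongside the existing arguments to the Metal kernel call. For context, a hedged usage sketch; the ops::sdpa entry point and its exact signature are assumptions here and should be checked against the candle-nn source:

// Hypothetical usage sketch; assumes candle_nn::ops::sdpa(q, k, v, scale,
// softcapping) exists with roughly this shape -- verify against candle-nn.
use candle_core::{DType, Device, Tensor};

fn run_sdpa() -> candle_core::Result<Tensor> {
    let dev = Device::new_metal(0)?;
    // Decode step: one query position, grouped-query K/V with 8 heads.
    let q = Tensor::zeros((1, 32, 1, 64), DType::F16, &dev)?;
    let k = Tensor::zeros((1, 8, 128, 64), DType::F16, &dev)?;
    let v = Tensor::zeros((1, 8, 128, 64), DType::F16, &dev)?;
    // scale = 1/sqrt(head_dim); softcapping = 1.0 presumably disables capping.
    candle_nn::ops::sdpa(&q, &k, &v, 1.0 / (64f32).sqrt(), 1.0)
}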
