diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 1bf4d20cbf..0e2614c5c8 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1879,6 +1879,7 @@ pub fn call_sdpa_vector( k_stride: &[usize], k_buffer: &Buffer, v_offset: usize, + v_stride: &[usize], v_buffer: &Buffer, output: &Buffer, alpha: f32, @@ -1890,7 +1891,8 @@ pub fn call_sdpa_vector( let gqa_factor = (q_shape[1] / k_shape[1]) as i32; let n = k_shape[2] as i32; let b = (q_shape[0] * q_shape[1]) as i32; - let stride = k_stride[1]; + let kstride = k_stride[1]; + let vstride = v_stride[1]; let name = match (bk, itype) { (32, SdpaDType::F16) => "sdpa_vector_float16_t_32", @@ -1949,15 +1951,20 @@ pub fn call_sdpa_vector( encoder.set_bytes( 6, std::mem::size_of::() as u64, - &stride as *const usize as *const c_void, + &kstride as *const usize as *const c_void, ); encoder.set_bytes( 7, + std::mem::size_of::() as u64, + &vstride as *const usize as *const c_void, + ); + encoder.set_bytes( + 8, std::mem::size_of::() as u64, &alpha as *const f32 as *const c_void, ); encoder.set_bytes( - 8, + 9, std::mem::size_of::() as u64, &softcapping as *const f32 as *const c_void, ); diff --git a/candle-metal-kernels/src/scaled_dot_product_attention.metal b/candle-metal-kernels/src/scaled_dot_product_attention.metal index d65cc621ac..bfe241d1b0 100644 --- a/candle-metal-kernels/src/scaled_dot_product_attention.metal +++ b/candle-metal-kernels/src/scaled_dot_product_attention.metal @@ -56,6 +56,7 @@ template const constant int& gqa_factor, const constant int& N, const constant size_t& k_stride, + const constant size_t& v_stride, const constant float& scale, const constant float& softcapping, uint3 tid [[threadgroup_position_in_grid]], @@ -82,7 +83,7 @@ template const int kv_head_idx = head_idx / gqa_factor; queries += head_idx * D + simd_lid * elem_per_thread; keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread; - values += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread; + values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread; out += head_idx * D + simd_gid * elem_per_thread; // Read the query and 0 the output accumulator @@ -1234,6 +1235,7 @@ instantiate_fast_inference_self_attention_kernel(half, half, 16, 16, 256, 2, 2); const constant int& gqa_factor, \ const constant int& N, \ const constant size_t& k_stride, \ + const constant size_t& v_stride, \ const constant float& scale, \ const constant float& softcapping, \ uint3 tid [[threadgroup_position_in_grid]], \ diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index b37c778c88..5e5d83a467 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -1094,6 +1094,7 @@ impl candle::CustomOp3 for Sdpa { k_l.stride(), k.buffer(), v_l.start_offset(), + v_l.stride(), v.buffer(), &output, self.scale,