Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Metal support
Browse files Browse the repository at this point in the history
EricLBuehler committed Jan 14, 2025
1 parent 072c715 commit a6639d0
Showing 4 changed files with 29 additions and 9 deletions.
10 changes: 10 additions & 0 deletions candle-metal-kernels/src/lib.rs
Original file line number Diff line number Diff line change
@@ -738,6 +738,7 @@ pub fn call_last_attn_softmax(
mask: &Buffer,
mask_offset: usize,
input_shape: &[usize],
+mask_shape: &[usize],
scale: f32,
ty: SdpaDType,
output: &Buffer,
@@ -749,6 +750,14 @@ pub fn call_last_attn_softmax(
let ne02 = input_shape[input_shape.len() - 3] as i64;
let ne03 = input_shape[input_shape.len() - 4] as i64;

+let elem_per_batch = if mask_shape.len() == 2 {
+0
+} else {
+let bs = input_shape[0];
+let el: usize = input_shape.iter().product();
+el / bs
+};
+
let mut nth = 32; // SIMD width
let name = if ne00 % 4 == 0 {
while nth < ne00 / 4 && nth * ne01 * ne02 * ne03 < 256 {
@@ -784,6 +793,7 @@ pub fn call_last_attn_softmax(
ne00,
ne01,
ne02,
+elem_per_batch as i64,
scale
)
);
20 changes: 14 additions & 6 deletions candle-metal-kernels/src/reduce.metal
Original file line number Diff line number Diff line change
@@ -827,6 +827,7 @@ kernel void attn_soft_max(
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
+constant int64_t & elem_per_batch,
constant float & scale,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
@@ -838,9 +839,12 @@ kernel void attn_soft_max(
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

-device const T * psrc0 = (device const T *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
-device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
-device T * pdst = (device T *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+const int64_t src_offset = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+const int64_t b_idx = elem_per_batch > 0 ? src_offset / elem_per_batch : 0;
+const int64_t mask_offset = b_idx * (ne00*ne01) + i01*ne00;
+device const T * psrc0 = (device const T *) src0 + src_offset;
+device const T * pmask = src1 != src0 ? (device const T *) src1 + mask_offset : nullptr;
+device T * pdst = (device T *) dst + src_offset;

float slope = 1.0f;

@@ -916,6 +920,7 @@ kernel void attn_soft_max_4(
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
+constant int64_t & elem_per_batch,
constant float & scale,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
@@ -927,9 +932,12 @@ kernel void attn_soft_max_4(
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

-device const T * psrc4 = (device const T *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
-device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
-device T * pdst4 = (device T *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+const int64_t src_offset = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+const int64_t b_idx = elem_per_batch > 0 ? src_offset / elem_per_batch : 0;
+const int64_t mask_offset = b_idx * (ne00*ne01) + i01*ne00;
+device const T * psrc0 = (device const T *) src0 + src_offset / 4;
+device const T * pmask = src1 != src0 ? (device const T *) src1 + mask_offset / 4 : nullptr;
+device T * pdst = (device T *) dst + src_offset / 4;

float slope = 1.0f;

6 changes: 4 additions & 2 deletions candle-nn/src/ops.rs
Original file line number Diff line number Diff line change
@@ -634,6 +634,7 @@ impl candle::InplaceOp2 for AttnSoftmaxLastDim {
mask_s.buffer(),
mask_l.start_offset() * mask_s.dtype().size_in_bytes(),
a_l.dims(),
+mask_l.dims(),
self.scale,
ty,
&a_s.buffer(),
@@ -713,7 +714,7 @@ impl candle::InplaceOp2 for AttnSoftmaxLastDim {
0
} else {
let bs = dims[0];
-el / bs;
+el / bs
};

let (nrows_x, ncols_x) = (el / dim_m1, dim_m1);
@@ -827,6 +828,7 @@ impl candle::CustomOp2 for AttnSoftmaxLastDim {
mask_s.buffer(),
mask_l.start_offset() * mask_s.dtype().size_in_bytes(),
a_l.dims(),
+mask_l.dims(),
self.scale,
ty,
&output,
@@ -905,7 +907,7 @@ impl candle::CustomOp2 for AttnSoftmaxLastDim {
0
} else {
let bs = dims[0];
-el / bs;
+el / bs
};
let (nrows_x, ncols_x) = (el / dim_m1, dim_m1);

2 changes: 1 addition & 1 deletion candle-nn/tests/ops.rs
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;

-use candle::{test_device, test_utils::to_vec3_round, DType, Device, IndexOp, Result, Tensor};
+use candle::{test_device, test_utils::to_vec3_round, Device, Result, Tensor};

fn softmax(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];

0 comments on commit a6639d0

Please sign in to comment.