Metal addmm support

EricLBuehler committed Dec 14, 2024
1 parent 6800496 commit 521e33c
Showing 3 changed files with 225 additions and 14 deletions.
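At the user level, this commit wires matmul_with_alpha_beta on Metal to MLX's fused addmm kernel, computing c = s * (a @ b) + c in place. A minimal sketch of the resulting behavior, mirroring the updated test below (addmm_demo is a hypothetical name, and the a/b values are an assumption, chosen to match the ones defined earlier in that test):

    use candle_core::{Device, Result, Tensor};

    fn addmm_demo(device: &Device) -> Result<()> {
        let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], (2, 2), device)?;
        let b = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], (2, 2), device)?;
        // c starts as ones and is updated in place by the addmm path.
        let mut c = Tensor::from_slice(&[1.0f32, 1.0, 1.0, 1.0], (2, 2), device)?;
        // c <- 2 * (a @ b) + c = 2 * [[7, 10], [15, 22]] + 1
        a.matmul_with_alpha_beta(&b, &mut c, Some(2.))?;
        assert_eq!(c.to_vec2::<f32>()?, &[[15.0f32, 21.0], [31.0, 45.0]]);
        Ok(())
    }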
35 changes: 21 additions & 14 deletions candle-core/src/metal_backend/mod.rs
@@ -1533,15 +1533,6 @@ impl BackendStorage for MetalStorage {
rhs_l: &Layout,
c_l: &Layout,
) -> Result<()> {
let name = match self.dtype {
DType::F32 => "sgemm",
DType::F16 => "hgemm",
DType::BF16 => "bgemm",
dtype => {
return Err(MetalError::Message(format!("matmul doesn't support {dtype:?}")).into())
}
};

let elem_count = b * m * n;

match c_l.contiguous_offsets() {
@@ -1557,24 +1548,40 @@
};

let command_buffer = self.device.command_buffer()?;
command_buffer.set_label("matmul");
candle_metal_kernels::call_gemm(
command_buffer.set_label("matmul_with_alpha_beta");

let dtype = match self.dtype {
DType::F32 => candle_metal_kernels::GemmDType::F32,
DType::F16 => candle_metal_kernels::GemmDType::F16,
DType::BF16 => candle_metal_kernels::GemmDType::BF16,
dtype => {
return Err(MetalError::Message(format!(
"matmul_with_alpha_beta doesn't support {dtype:?}"
))
.into())
}
};
candle_metal_kernels::call_mlx_addmm(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
dtype,
(b, m, n, k),
lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer,
c_l.stride(),
c_l.start_offset() * c.dtype.size_in_bytes(),
&c.buffer,
&c.buffer,
s.unwrap_or(1.) as f32,
1.,
)
.map_err(MetalError::from)?;

Ok(())
}

@@ -1586,9 +1593,9 @@
lhs_l: &Layout,
rhs_l: &Layout,
) -> Result<Self> {
let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul_with_alpha")?;
let command_buffer = self.device.command_buffer()?;
command_buffer.set_label("matmul");
command_buffer.set_label("matmul_with_alpha");
if self.dtype == DType::BF16 {
if s.unwrap_or(1.) != 1. {
return Err(
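Note on the hunk above: matmul_with_alpha_beta uses c as both the C input and the output D (the same c.buffer is bound twice), passing s.unwrap_or(1.) as alpha and a fixed beta of 1.0, so the kernel computes c = s * (lhs @ rhs) + c in place.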
1 change: 1 addition & 0 deletions candle-core/tests/matmul_tests.rs
@@ -127,6 +127,7 @@ fn matmul_alpha_beta(device: &Device) -> Result<()> {
let data = vec![1.0f32, 1.0, 1.0, 1.0];
let mut c = Tensor::from_slice(&data, (2, 2), device)?;

println!("{}", a.matmul(&b)?);
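// In-place update: c <- 2 * (a @ b) + c, which the assertion below checks.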
a.matmul_with_alpha_beta(&b, &mut c, Some(2.))?;
assert_eq!(c.to_vec2::<f32>()?, &[[15.0f32, 21.0], [31.0, 45.0]]);
Ok(())
203 changes: 203 additions & 0 deletions candle-metal-kernels/src/lib.rs
@@ -3035,6 +3035,209 @@ pub fn call_mlx_gemm(
std::mem::size_of::<GemmParams>() as u64,
&gemm_params as *const GemmParams as *const c_void,
);
encoder.set_bytes(
6, // batch_shape
std::mem::size_of::<i32>() as u64,
&(b as i32) as *const i32 as *const c_void,
);
encoder.set_bytes(
7,
(std::mem::size_of::<isize>() * batch_strides.len()) as u64,
batch_strides.as_ptr() as *const c_void,
);

let grid_size = MTLSize {
width: tn as u64,
height: tm as u64,
depth: /* batch_size_out */ b as u64,
};
let group_size = MTLSize {
width: 32,
height: wn,
depth: wm,
};
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(grid_size, group_size);
Ok(())
}

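/// Launch the MLX steel GEMM kernel with its fused epilogue enabled:
/// output = alpha * (lhs @ rhs) + beta * c, batched over b.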
#[allow(clippy::too_many_arguments)]
pub fn call_mlx_addmm(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
dtype: GemmDType,
(b, m, n, k): (usize, usize, usize, usize),
lhs_stride: &[usize],
lhs_offset: usize,
lhs_buffer: &Buffer,
rhs_stride: &[usize],
rhs_offset: usize,
rhs_buffer: &Buffer,
c_stride: &[usize],
c_offset: usize,
c_buffer: &Buffer,
output: &Buffer,
alpha: f32,
beta: f32,
) -> Result<(), MetalKernelError> {
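// These #[repr(C)] structs must match the GEMMParams and GEMMAddMMParams
// layouts expected by the MLX steel GEMM Metal kernels.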
#[derive(Debug)]
#[repr(C)]
struct GemmParams {
m: i32,
n: i32,
k: i32,
lda: i32,
ldb: i32,
ldd: i32,
tiles_n: i32,
tiles_m: i32,
batch_stride_a: isize,
batch_stride_b: isize,
batch_stride_d: isize,
swizzle_log: i32,
gemm_k_iterations_aligned: i32,
batch_ndim: i32,
}

#[derive(Debug)]
#[repr(C)]
struct GEMMAddMMParams {
ldc: i32,
fdc: i32,
batch_stride_c: isize,
alpha: f32,
beta: f32,
}

assert!(rhs_stride.len() >= 2);
assert!(lhs_stride.len() >= 2);
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
// lhs has shape b, m, k
// We also allow for the case where the stride on the minor dimension is not as expected but
// there is a single element.
let (lda, a_trans) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
(k as i32, false)
} else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
(m as i32, true)
} else {
return Err(MetalKernelError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?;
};
// rhs has shape b, k, n
let (ldb, b_trans) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
(n as i32, false)
} else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
(k as i32, true)
} else {
return Err(MetalKernelError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?;
};
let (bm, bn, bk, wn, wm) = (32, 32, 16, 2, 2);
// https://github.com/ml-explore/mlx/blob/02efb310cac667bc547d1b96f21596c221f84fe7/mlx/backend/metal/matmul.cpp#L422
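// use_out_source makes the kernel read the extra C operand, and do_axpby
// applies the alpha/beta scaling; both are function constants specialized
// at pipeline-creation time in the MLX kernel linked above.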
let constants = Some(ConstantValues::new(vec![
(10, Value::Bool(/* has_batch */ b > 1)),
(100, Value::Bool(/* use_out_source */ true)),
(110, Value::Bool(/* do_axpby */ true)),
(200, Value::Bool(/* align_m */ m % bm == 0)),
(201, Value::Bool(/* align_n */ n % bn == 0)),
(202, Value::Bool(/* align_k */ k % bk == 0)),
(300, Value::Bool(/* do_gather */ false)),
]));

let swizzle_log = 0;
let tile = 1 << swizzle_log;
let tn = n.div_ceil(bn);
let tm = m.div_ceil(bm);
let tn = tn * tile;
let tm = tm.div_ceil(tile);
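// With swizzle_log = 0 the tile swizzle is a no-op (tile = 1), leaving tn
// and tm as the plain thread-group counts along n and m.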

let batch_stride_a = if lhs_stride.len() > 2 {
lhs_stride[lhs_stride.len() - 3]
} else {
m * k
};
let batch_stride_b = if rhs_stride.len() > 2 {
rhs_stride[rhs_stride.len() - 3]
} else {
n * k
};

let gemm_params = GemmParams {
m: m as i32,
n: n as i32,
k: k as i32,
lda,
ldb,
ldd: n as i32,
tiles_n: tn as i32,
tiles_m: tm as i32,
swizzle_log,
batch_stride_a: batch_stride_a as isize,
batch_stride_b: batch_stride_b as isize,
batch_stride_d: (m * n) as isize,
batch_ndim: 1i32,
gemm_k_iterations_aligned: (k / bk) as i32,
};
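// ldc and fdc are the row stride and element stride of C taken from its
// layout, so a transposed or otherwise strided C needs no copy.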
let gemm_addmm_params = GEMMAddMMParams {
ldc: c_stride[c_stride.len() - 2] as i32,
fdc: c_stride[c_stride.len() - 1] as i32,
batch_stride_c: (m * n) as isize,
alpha,
beta,
};
let batch_strides = [
gemm_params.batch_stride_a,
gemm_params.batch_stride_b,
gemm_addmm_params.batch_stride_c,
];

// TODO(laurent): generate the name
// template [[host_name("gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn)]]
let name = match (dtype, a_trans, b_trans) {
(GemmDType::F32, false, false) => "gemm_nn_f32_f32_32_32_16_2_2",
(GemmDType::F32, true, false) => "gemm_tn_f32_f32_32_32_16_2_2",
(GemmDType::F32, false, true) => "gemm_nt_f32_f32_32_32_16_2_2",
(GemmDType::F32, true, true) => "gemm_tt_f32_f32_32_32_16_2_2",
(GemmDType::BF16, false, false) => "gemm_nn_bf16_bf16_32_32_16_2_2",
(GemmDType::BF16, true, false) => "gemm_tn_bf16_bf16_32_32_16_2_2",
(GemmDType::BF16, false, true) => "gemm_nt_bf16_bf16_32_32_16_2_2",
(GemmDType::BF16, true, true) => "gemm_tt_bf16_bf16_32_32_16_2_2",
(GemmDType::F16, false, false) => "gemm_nn_f16_f16_32_32_16_2_2",
(GemmDType::F16, true, false) => "gemm_tn_f16_f16_32_32_16_2_2",
(GemmDType::F16, false, true) => "gemm_nt_f16_f16_32_32_16_2_2",
(GemmDType::F16, true, true) => "gemm_tt_f16_f16_32_32_16_2_2",
};
let pipeline = kernels.load_pipeline_with_constants(device, Source::Gemm, name, constants)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
encoder.set_buffer(2, Some(c_buffer), c_offset as NSUInteger);
encoder.set_buffer(3, Some(output), 0);
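// Buffer indices follow the kernel signature: 0 = A, 1 = B, 2 = C,
// 3 = D (output), 4 = GEMMParams, 5 = GEMMAddMMParams, 6 = batch_shape,
// 7 = batch_strides.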
encoder.set_bytes(
4,
std::mem::size_of::<GemmParams>() as u64,
&gemm_params as *const GemmParams as *const c_void,
);
encoder.set_bytes(
5,
std::mem::size_of::<GEMMAddMMParams>() as u64,
&gemm_addmm_params as *const GEMMAddMMParams as *const c_void,
);
encoder.set_bytes(
6, // batch_shape
std::mem::size_of::<i32>() as u64,
&(b as i32) as *const i32 as *const c_void,
);
encoder.set_bytes(
7,
(std::mem::size_of::<isize>() * batch_strides.len()) as u64,
batch_strides.as_ptr() as *const c_void,
);

let grid_size = MTLSize {
width: tn as u64,
height: tm as u64,
depth: /* batch_size_out */ b as u64,
};
let group_size = MTLSize {
width: 32,
height: wn,
depth: wm,
};
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(c_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(grid_size, group_size);
Ok(())
}
