diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 4ea46b23b..9fa9df332 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -578,6 +578,16 @@ pub fn call_cast_strided( Ok(()) } +#[inline] +fn calculate_reduce_threads(work_per_threadgroup: usize) -> NSUInteger { + let work_split = work_per_threadgroup / 2; + let mut w = 2; + while w < work_split { + w *= 2; + } + w as NSUInteger +} + #[allow(clippy::too_many_arguments)] pub fn call_reduce_contiguous( device: &Device, @@ -616,17 +626,10 @@ pub fn call_reduce_contiguous( depth: 1, }; - let work_split = work_per_threadgroup / 2; - let mut w = 2; - while w < work_split { - w *= 2; - } - let width = std::cmp::min( pipeline.max_total_threads_per_threadgroup(), - (elements_to_sum as u64).div_ceil(2), - ) - .next_power_of_two(); + calculate_reduce_threads(work_per_threadgroup), + ); let thread_group_size = MTLSize { width, @@ -680,15 +683,9 @@ pub fn call_reduce_strided( depth: 1, }; - let work_split = work_per_threadgroup / 2; - let mut w = 2; - while w < work_split { - w *= 2; - } - let width = std::cmp::min( pipeline.max_total_threads_per_threadgroup(), - w as NSUInteger, + calculate_reduce_threads(work_per_threadgroup), ); let thread_group_size = MTLSize { @@ -734,15 +731,9 @@ pub fn call_last_softmax( depth: 1, }; - let work_split = work_per_threadgroup / 2; - let mut w = 2; - while w < work_split { - w *= 2; - } - let width = std::cmp::min( pipeline.max_total_threads_per_threadgroup(), - w as NSUInteger, + calculate_reduce_threads(work_per_threadgroup), ); let thread_group_size = MTLSize {