Skip to content

Commit

Permalink
Add calculate_reduce_threads util
Browse files Browse the repository at this point in the history
  • Loading branch information
ivarflakstad committed Jan 13, 2025
1 parent 5c59669 commit bcbbad6
Showing 1 changed file with 14 additions and 23 deletions.
37 changes: 14 additions & 23 deletions candle-metal-kernels/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,16 @@ pub fn call_cast_strided(
Ok(())
}

/// Returns the threadgroup width to request for a reduction over
/// `work_per_threadgroup` elements.
///
/// The work is halved and then rounded up to the next power of two, with a
/// lower bound of 2. For example, `10` elements give `5` after halving, which
/// rounds up to `8`; any input of `5` or less yields `2`.
///
/// The result is not clamped to any device limit here; callers are expected
/// to cap it themselves (e.g. against the pipeline's maximum threads per
/// threadgroup).
#[inline]
fn calculate_reduce_threads(work_per_threadgroup: usize) -> NSUInteger {
    let work_split = work_per_threadgroup / 2;
    // `work_split` is at most `usize::MAX / 2`, so the next power of two always
    // fits in a `usize` and `next_power_of_two` cannot overflow. `max(2)`
    // preserves the floor for `work_split` values of 0 and 1.
    work_split.next_power_of_two().max(2) as NSUInteger
}

#[allow(clippy::too_many_arguments)]
pub fn call_reduce_contiguous(
device: &Device,
Expand Down Expand Up @@ -616,17 +626,10 @@ pub fn call_reduce_contiguous(
depth: 1,
};

let work_split = work_per_threadgroup / 2;
let mut w = 2;
while w < work_split {
w *= 2;
}

let width = std::cmp::min(
pipeline.max_total_threads_per_threadgroup(),
(elements_to_sum as u64).div_ceil(2),
)
.next_power_of_two();
calculate_reduce_threads(work_per_threadgroup),
);

let thread_group_size = MTLSize {
width,
Expand Down Expand Up @@ -680,15 +683,9 @@ pub fn call_reduce_strided(
depth: 1,
};

let work_split = work_per_threadgroup / 2;
let mut w = 2;
while w < work_split {
w *= 2;
}

let width = std::cmp::min(
pipeline.max_total_threads_per_threadgroup(),
w as NSUInteger,
calculate_reduce_threads(work_per_threadgroup),
);

let thread_group_size = MTLSize {
Expand Down Expand Up @@ -734,15 +731,9 @@ pub fn call_last_softmax(
depth: 1,
};

let work_split = work_per_threadgroup / 2;
let mut w = 2;
while w < work_split {
w *= 2;
}

let width = std::cmp::min(
pipeline.max_total_threads_per_threadgroup(),
w as NSUInteger,
calculate_reduce_threads(work_per_threadgroup),
);

let thread_group_size = MTLSize {
Expand Down

0 comments on commit bcbbad6

Please sign in to comment.